Writing a sort in Dask

Hi Dask community! I’m evaluating different distributed programming models, and as part of that I’m implementing a distributed radix sort with a number of different tools as a way of comparing their capabilities. I’m new to Dask, but I’m not new to parallel computing.

First question: is writing a sort in Dask / Dask Distributed a reasonable thing to do? If it is, anybody know of an existing implementation I can look at? If it is not, why not?

Second question: Is my approach (described below) a reasonable way to use Dask or is there some other way that I should think about the partitioning steps in the radix sort?

Third question: Any idea what is going wrong with my program, which seems to use up memory and hang?

Here is the approach I have been taking. To avoid mutating the data (my understanding is that the elements of Dask arrays are immutable), I settled on implementing the partition step of the radix sort as a function that can be applied (in parallel) to each chunk of a Dask array and that returns a list of arrays holding the chunk’s elements grouped by the current digit (so subarrays[0] would be the portion of the chunk’s elements whose current digit is 0). I can’t just sort the chunks independently; that would give sorted chunks but not sorted data overall. So I am trying to sort of transpose these lists of arrays, so that the final result has the digit-0 subarrays from all the chunks, then the digit-1 subarrays from all the chunks, and so on. Note that this is not a normal transpose because each of these subarrays can have a different length.

I was running into a lot of trouble getting map_blocks to work with my partition function (which returns a list), but I found Use map_blocks with function that returns a tuple - #7 by ParticularMiner and was able to get something that seems to run, at least at first.

The problem I am facing now is that this program seems to hang after processing a couple of digits. I think it’s doing too much work in the original Python interpreter (rather than in the Dask worker processes), but it’s not obvious to me what the problem is. Maybe Dask is creating task graphs for the full computation and getting hung up? I was trying to avoid that by using persist.

Any help is appreciated, thanks.

Here is the full program if anybody wants to try it (although, please be aware it will probably try to use more and more memory, so be prepared to kill it if you try it).

# this version seems to hang and give error
# messages about large task graphs.

# it uses the strategy described in 
# https://dask.discourse.group/t/use-map-blocks-with-function-that-returns-a-tuple/84/7
# to have 'partition' return a list.

import argparse
import dask
from dask.distributed import Client, wait
import dask.array as da
import numpy as np
import time

radix = 8
n_buckets = 1 << radix
n_digits = 64 // radix
mask = n_buckets - 1
trace = True

# pair each random value with its global index, producing records of
# type (key, original position); block_info gives this chunk's offset
def make_structured(x_chunk, block_info=None):
    if trace:
        print("make_structured ", repr(x_chunk), block_info[0])
    ret = np.zeros(x_chunk.size, dtype='u8, u8')
    start = block_info[0]['array-location'][0][0]
    if trace:
        print("start is ", start)
    for i, rand in enumerate(x_chunk):
        ret[i] = (rand, start + i)
    return ret

# extract the digit-th radix-bit group from the key field x[0]
def bkt(x, digit):
    #print("bkt", hex(x[0]), digit)
    ret = (x[0] >> np.uint64((radix*digit))) & np.uint64(mask)
    #print("bkt ret", hex(ret))
    return ret

# now compute the data arrays for each key from each block
def partition(x_chunk, digit):
    # generate an array-of-arrays
    # inner arrays are the data for each bucket
    #if trace:
    #    print("partition ", x_chunk, digit)
    digit = np.uint64(digit)
    # count the number in each bucket
    counts = np.zeros(n_buckets, dtype='u8')
    for x in x_chunk:
        counts[bkt(x, digit)] += 1
    # allocate the subarrays
    subarrays = []
    for c in counts:
        subarrays.append(np.zeros(c, dtype='u8, u8'))
    # store the data into the subarrays
    counts.fill(0)
    for x in x_chunk:
        b = bkt(x, digit)
        subarrays[b][counts[b]] = x
        counts[b] += 1
    # return the subarrays
    #if trace:
    #    print("partition returning ", subarrays)
    return subarrays


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, help="number of records to sort",
                        default=128*1024*1024)
    parser.add_argument("--n-workers", type=int, help="number of workers", default=16)
    parser.add_argument("--chunk_size", type=int, help="chunk size", default=None)
    args = parser.parse_args()

    n_workers=args.n_workers
    client = Client(processes=True, n_workers=n_workers)

    n = args.n
    chunk_size = args.chunk_size
    if not chunk_size:
        chunk_size = n // (n_workers*8)
    chunk_size = max(1, chunk_size)

    # run a local partition just to make sure everything works,
    # before going distributed
    test_rng = np.random.default_rng()
    test_r = test_rng.integers(0, 0xffffffffffffffff, size=10, dtype='u8')
    test_x = make_structured(test_r, [{'array-location': [(0, 10)]}])
    bkt(test_x[0], 0)
    for digit in range(n_digits):
        subarrays = partition(test_x, digit) 
        test_x = np.concatenate(subarrays, axis=0)
        #print("after digit", digit)
        #for x in test_x:
        #    print(hex(x[0]))
    for i in range(test_x.size):
        if i > 0:
            assert(test_x[i-1][0] <= test_x[i][0])

    print("Generating", n, "records of input with ", n_workers,
          "workers and chunk size ", chunk_size)
    start = time.time()

    rng = da.random.default_rng()
    r = rng.integers(0, 0xffffffffffffffff, size=n, dtype='u8',
                     chunks=chunk_size)

    # create the input data, consisting of pairs of 8-byte values
    x = da.map_blocks(make_structured, r, dtype='u8, u8')
    #x = x.persist()
    x = client.persist(x)
    wait(x)
 
    stop = time.time()
    print("Generated input in ", stop-start, " seconds")

    if trace:
        print("generated input is ", x.compute())

    print("Sorting", n, "records with", n_workers,
          "workers and chunk size", chunk_size,
          "and radix ", radix, "(", n_buckets, " buckets )")
    start = time.time()

    meta = np.array([], dtype=x.dtype)

    for digit in range(n_digits):
        print("digit", digit)

        # bind digit via a default argument so the function captures
        # the current value rather than the loop variable
        def partition_by_digit(x_chunk, digit=digit):
            return partition(x_chunk, digit)

        # create an array where chunks are lists; each list contains
        # the subarrays starting with that digit 
        list_arr = x.map_blocks(partition_by_digit, dtype=x.dtype, meta=meta)
        # so that the below calls do not recompute the partition
        list_arr = client.persist(list_arr)
        #list_arr.persist()

        to_concat = [ ]
        for d in range(n_buckets):
            # bind d via a default argument; otherwise every get_ith
            # captures the loop variable and would see its final value
            # when the functions are serialized after the loop
            def get_ith(x, d=d):
                return x[d]

            to_concat.append(list_arr.map_blocks(get_ith, dtype=x.dtype, meta=meta))

        x = da.concatenate(to_concat, axis=0).rechunk(chunk_size)
        x = client.persist(x)

    print("Waiting for work")
    wait(x)
    
    stop = time.time()
    print("Sorted in ", stop-start, " seconds")

    if trace:
        tmp = x.compute()
        for rec in tmp:
            print(hex(rec[0]), rec[1])

    exit()

Hi @mppf, welcome to Dask community!

This answer is easy: yes, you can write a distributed sort with Dask. There is no Dask Array implementation, but one exists for DataFrames.

Well, the code you provided is a bit hard to read, and I’m not sure I follow the approach. Do you think you could simplify the code a bit or try to describe your approach in more detail? You’re correct that you should not mutate the inputs.

I see you are using persist a lot, and every persisted array seems to be about 1 GB in size. How much memory do you have? Have you taken a look at the Dashboard?

Using persist won’t prevent Dask from creating a complete graph, but it will keep some results in memory.
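
For instance, here is a minimal sketch of what I mean (the names and sizes are just for illustration):

import dask.array as da
from dask.distributed import Client

client = Client(n_workers=2)

x = da.random.random(1_000_000, chunks=100_000)
y = (x + 1).persist()   # starts computing in the background;
                        # finished chunks are kept in worker memory
z = (y * 2).sum()       # this new graph starts from y's in-memory chunks
print(z.compute())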

Hi & thanks for your reply.

This answer is easy: yes, you can write a distributed sort with Dask.

Right, I’m aware of the DataFrame implementation, which demonstrates that it’s possible to write a sort in Dask. That said, I’m not sure I’ve found its implementation yet and would appreciate pointers to it, or to descriptions of it, if you know where to look. At present, my understanding is that the DataFrame sort relies on details of how DataFrames are represented, which makes me think that perhaps it is not reasonable to try to write a general sort in Dask (using Dask arrays, say). Additionally, it might not be reasonable to write a radix sort in Dask; maybe Dask is only well suited to other kinds of sort algorithms.
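
(For reference, what I mean by the DataFrame sort is the user-facing API sketched below, with made-up data; I haven’t found where the underlying shuffle is implemented.)

import pandas as pd
import dask.dataframe as dd

pdf = pd.DataFrame({"key": [3, 1, 4, 1, 5, 9, 2, 6]})
ddf = dd.from_pandas(pdf, npartitions=2)
print(ddf.sort_values("key").compute())   # distributed sort across partitions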

Well, the code you provided is a bit hard to read, and I’m not sure I follow the approach. Do you think you could simplify the code a bit or try to describe your approach in more detail?

I’m trying to write an LSB radix sort. It’s a sort algorithm that works by sorting by the rightmost digit of each number, then the next, then the next, to get everything sorted. You can see an example in Radix sort - Wikipedia. Let’s consider the first part of that example so I can tie it back to what I’m trying to do with Dask.

Input

[170, 45, 75, 90, 2, 802, 2, 66]

Sorted by last digit

[{170, 90}, {2, 802, 2}, {45, 75}, {66}]

Here the { } regions are the groups that end with that particular digit.
The partition function in my program generates these same groupings, but it operates on one chunk at a time (so that it can run in parallel).

Let’s imagine a similar example with the data divided into 2 chunks:

Input
[170, 45, 75, 90] [2, 800, 2, 65]

Chunks sorted by last digit
{0: [170, 90], 5: [45, 75]} {0: [800], 2: [2, 2], 5: [65]}

Now comes the part that is sort of a transpose (implemented with to_concat in the code); we need these subarrays ordered first by digit and then by chunk:

Data sorted by last digit
[170, 90] # digit 0 from chunk 0
[800] # digit 0 from chunk 1
# ...
[ ] # digit 2 from chunk 0
[2, 2] # digit 2 from chunk 1
# ...
[45, 75] # digit 5 from chunk 0
[65] # digit 5 from chunk 1

The final step is to concatenate these arrays (and then continue the algorithm with that array as the input, sorting by the next digit, until we are out of digits).
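
To make that concrete, here is that whole pass written out in plain NumPy (with decimal digits for readability; my actual program uses 8-bit binary digits and map_blocks instead of this Python loop):

import numpy as np

chunks = [np.array([170, 45, 75, 90]), np.array([2, 800, 2, 65])]

def partition_chunk(chunk, digit):
    # one (possibly empty) subarray per digit value 0-9
    return [chunk[(chunk // 10**digit) % 10 == d] for d in range(10)]

partitioned = [partition_chunk(c, 0) for c in chunks]

# the "transpose": bucket 0 from every chunk, then bucket 1, and so on
ordered = [partitioned[c][d] for d in range(10) for c in range(len(chunks))]
print(np.concatenate(ordered))  # [170  90 800   2   2  45  75  65]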

I see you are using persist a lot, and every persisted array seems to be about 1 GB in size. How much memory do you have? Have you taken a look at the Dashboard?

I’ve been experimenting with this on a system with 64 GB of memory. I’ve been watching the processes with top, and I see the memory usage of the original Python interpreter growing (seemingly without bound) while the workers use memory briefly and then free it. I have tried the Dashboard, which works OK until the program prints the large-graph warning (see below) and reaches digit 1, after which everything in the Dashboard hangs.

Here’s the output from running the program:

Sorting 134217728 records with 16 workers and chunk size 1048576 and radix  8 ( 256  buckets )
digit 0
/home/mppf/explore/dask/venv/lib/python3.12/site-packages/distributed/client.py:3370: UserWarning: Sending large graph of size 11.34 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
digit 1

followed by errors along the lines of 2025-01-27 12:36:59,425 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat. and other messages indicating disconnection.

The warning, combined with the memory behavior in top, makes me think that my program is causing Dask to create an overly complex task graph. What I would like is for the partially sorted array to be saved in memory on the workers, so that the task graph for each pass looks like the graph for the first pass.

Using persist won’t prevent Dask from creating a complete graph, but it will keep some results in memory.

Is there a way I can ask Dask to save an array in memory on the workers without holding on to the way it was computed (the task graph)? I imagine I could save each intermediate array to a file to achieve that, or compute it and hold it in memory in the running Python interpreter, but I’d like to do this in a way that keeps everything in memory on the workers.
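
For example, the in-the-interpreter version of that workaround would be something like this sketch, which funnels all of the data through the client process:

# hypothetical workaround: round-trip through the client; this discards
# the old task graph but first pulls every record into client memory
materialized = x.compute()
x = da.from_array(materialized, chunks=chunk_size)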

Thanks again.