How to avoid shuffling when doing groupby apply?

Hi friends, I want to make my groupby-apply more efficient by making sure that values from the same group sit only in the same partition. However, when I test the idea with the following program, there still seems to be shuffling in the graph. I understand that Dask has no idea that I’m preparing the data this way, but is there a way to let it know? (I cannot set the groupby columns as the index because a MultiIndex is not allowed.)

import datetime

import dask
import pandas as pd
from dask.dataframe import from_pandas
from distributed import Client, LocalCluster


def tap(df):
    print(df)
    return df


def main():
    with dask.config.set({'distributed.scheduler.allowed-failures': 0, "distributed.logging.distributed": "DEBUG"}):
        with LocalCluster(n_workers=1, threads_per_worker=1, memory_limit="200MiB") as cluster, Client(cluster) as client:
            df = pd.DataFrame(dict(a=list('xxyyzz'),
                                   c=[datetime.datetime(2010, 1, 1),
                                      datetime.datetime(2010, 1, 1),
                                      datetime.datetime(2010, 1, 1),
                                      datetime.datetime(2010, 2, 1),
                                      datetime.datetime(2010, 2, 1),
                                      datetime.datetime(2010, 2, 1)],
                                   d=[1, 2, 3, 4, 5, 6],
                                   ))

            print("------pandas------")
            print(df)
            ddf = from_pandas(df, npartitions=3)
            print("------partitions------")
            ddf.map_partitions(tap, meta=ddf).compute(scheduler=client)

            print("------group by------")
            group = ddf.groupby('a', 'c').apply(tap, meta=ddf)
            group.visualize("chart.svg")
            group.compute(scheduler=client)


if __name__ == "__main__":
    main()

This outputs:

------partitions------
   a          c  d
0  x 2010-01-01  1
1  x 2010-01-01  2
   a          c  d
2  y 2010-01-01  3
3  y 2010-02-01  4
   a          c  d
4  z 2010-02-01  5
5  z 2010-02-01  6
------group by------
   a          c  d
2  y 2010-01-01  3
3  y 2010-02-01  4
   a          c  d
0  x 2010-01-01  1
1  x 2010-01-01  2
   a          c  d
4  z 2010-02-01  5
5  z 2010-02-01  6

I updated the test program to give a more well-rounded example.

import datetime

import dask
import pandas as pd
from dask.dataframe import from_pandas
from distributed import Client, LocalCluster

def tap(df):
    print(df)
    return df


def main():
    with dask.config.set({'distributed.scheduler.allowed-failures': 0, "distributed.logging.distributed": "DEBUG"}):
        with LocalCluster(n_workers=1, threads_per_worker=1, memory_limit="200MiB") as cluster, Client(cluster) as client:
            df = pd.DataFrame(dict(a=list('yyyyyyyyxxzz'),
                                   c=[datetime.datetime(2010, 3, 1),
                                      datetime.datetime(2010, 3, 1),
                                      datetime.datetime(2010, 3, 1),
                                      datetime.datetime(2010, 3, 1),
                                      datetime.datetime(2010, 3, 1),
                                      datetime.datetime(2010, 3, 1),
                                      datetime.datetime(2010, 1, 1),
                                      datetime.datetime(2010, 1, 1),
                                      datetime.datetime(2010, 1, 1),
                                      datetime.datetime(2010, 2, 1),
                                      datetime.datetime(2010, 2, 1),
                                      datetime.datetime(2010, 2, 1),
                                      ],
                                   d=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                                   ))

            print("------pandas------")
            print(df)
            ddf = from_pandas(df, npartitions=2)
            print("------partitions------")
            ddf.map_partitions(tap, meta=ddf).compute(scheduler=client)

            print("------group by------")
            group = ddf.groupby(by=['a', 'c']).apply(tap, meta=ddf)
            # group = perfect_apply(ddf.groupby(by=['a', 'c']), tap, meta=ddf)
            group.visualize("chart.svg")
            group.compute(scheduler=client)

if __name__ == "__main__":
    main()

So even though the only group in partition 1 is (y, 2010-3-1), and partition 2 has the others ((y, 2010-1-1), (x, 2010-1-1), (x, 2010-2-1), (z, 2010-2-1)), Dask still includes the shuffle steps, as shown in the screenshot. I did manage to work around that using set_index(), but it has another issue, which I’ll explain at the end.

I ended up monkey-patching the apply function and removing the should_shuffle check. If you want to try it out, add the following definition and use the group = perfect_apply(ddf.groupby(by=['a', 'c']), tap, meta=ddf) line instead.

from dask.dataframe import map_partitions
from dask.dataframe.dispatch import make_meta
from dask.dataframe.groupby import _groupby_slice_apply
from dask.utils import funcname


def perfect_apply(groupby_object, func, meta, *args, **kwargs):
    # Like GroupBy.apply, but always maps func over the existing partitions
    # instead of shuffling the data by the groupby keys first.
    meta = make_meta(meta, parent_meta=groupby_object._meta.obj)

    kwargs["meta"] = meta
    return map_partitions(
        _groupby_slice_apply,
        groupby_object.obj,
        groupby_object.by,
        groupby_object._slice,
        func,
        token=funcname(func),
        *args,
        group_keys=groupby_object.group_keys,
        **groupby_object.observed,
        **groupby_object.dropna,
        **kwargs,
    )

This does the trick and still produces the correct output. I’m wondering if there’s any case where this would yield incorrect output?
[Screenshot: task graph showing the shuffle steps]

About another option:
After reading the code, I realized that as long as the index column is ONE of the groupby columns, the shuffling can be avoided. However, set_index will order the data by the index column (the data either has to be pre-sorted that way already, or the sort involves some computation). The change of order is not ideal because it will likely break the balance of row counts across my existing partitions. That’s why I didn’t go down this path.
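
For reference, the set_index() workaround I tried looks roughly like this (just a sketch; the sort/repartition that set_index triggers is exactly what I want to avoid):

# With one of the groupby columns as the sorted, known-divisions index,
# the groupby apply no longer shuffles, but set_index itself reorders the
# rows and reshapes my partitions.
ddf2 = ddf.set_index('a')  # sorts/repartitions by 'a'
group = ddf2.groupby(by=['a', 'c']).apply(tap, meta=ddf2)
group.compute(scheduler=client)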

Any comment is appreciated!

@ubw218 Sorry for the delayed response.

I understand that Dask has no idea that I’m preparing the data this way, but is there a way to let it know?

Not that I’m aware of. I see you’ve used map_partitions already, and that’s what I’d suggest too. If you can create partitions based on groups beforehand, map_partitions would be your best option.
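
For example, something along these lines (just a sketch, reusing your tap function and assuming each partition already contains only complete groups):

# Run the pandas groupby-apply independently inside every partition, so no
# Dask-level shuffle is introduced. This is only correct if a group never
# spans more than one partition.
result = ddf.map_partitions(
    lambda pdf: pdf.groupby(['a', 'c'], group_keys=False).apply(tap),
    meta=ddf,
)
result.compute()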

Your perfect_apply function is an interesting idea (and thanks for sharing it!), but it relies on Dask’s internal functions, which we generally don’t recommend using.

This does the trick and still produces the correct output. I’m wondering if there’s any case where this would yield incorrect output?

I can’t think of a case right now, but I wouldn’t rely on this to work flawlessly. :smile: