Implementing a custom lambda function in Dask

I have written a custom lambda function that needs to be applied after a groupby operation on a DataFrame. The lambda concatenates all **unique** strings within each group using a joiner such as ", ". I am trying to implement the same functionality with the Dask library in Python, but I get the error shown below. Could anybody guide me on how to implement this lambda function in Dask?

Implemented in Pandas:

import pandas as pd

A = pd.DataFrame(data={"A": ["saad", "saad", "saad", "saad", "nimra", "asad", "nimra", "nimra", "asad"],
                       "B": ["hello", "hello", "saad", "whatsup?", "yup", "nup", "saad", "saad", "nup"],
                       "C": ["hello", "hello", "saad", "whatsup?", "yup", "nup", "saad", "saad", "nup"]})
# Column B only: unique values per group, joined with ", "
A.groupby("A")["B"].unique().apply(", ".join)
# All non-grouping columns (B and C) at once
A.groupby("A").agg(lambda s: ", ".join(s.unique()))

This code works perfectly fine and produces the correct output (shown for the second statement, which aggregates both B and C):

A      B                      C
asad   nup                    nup
nimra  yup, saad              yup, saad
saad   hello, saad, whatsup?  hello, saad, whatsup?
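As a reference point, here is a plain-Python sketch of what the pandas one-liner computes for a single column: `Series.unique()` keeps values in order of first appearance, and `groupby` sorts the group keys by default. `group_join_unique` is a hypothetical helper written for illustration, not part of pandas or Dask.

```python
from collections import defaultdict


def group_join_unique(keys, values, sep=", "):
    """Group `values` by `keys` and join each group's unique values with `sep`.

    Unique values keep their order of first appearance (like pandas'
    Series.unique()), and groups are returned sorted by key (like the
    default groupby(sort=True)).
    """
    groups = defaultdict(dict)  # a dict doubles as an insertion-ordered set
    for key, value in zip(keys, values):
        groups[key][value] = None
    return {key: sep.join(seen) for key, seen in sorted(groups.items())}


keys = ["saad", "saad", "saad", "saad", "nimra", "asad", "nimra", "nimra", "asad"]
col_b = ["hello", "hello", "saad", "whatsup?", "yup", "nup", "saad", "saad", "nup"]
print(group_join_unique(keys, col_b))
# {'asad': 'nup', 'nimra': 'yup, saad', 'saad': 'hello, saad, whatsup?'}
```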

Implemented in Dask
I tried to implement it in Dask using this code:

import dask.dataframe as dd

A_1 = dd.from_pandas(A, npartitions=2)
A_1.groupby("A").agg(lambda s: ", ".join(s.unique()))

However, the following error occurs:
ValueError                                Traceback (most recent call last)
Cell In[20], line 1
----> 1 A_1.groupby("A").agg(lambda s: ', '.join(s.unique()))

File /opt/miniconda/lib/python3.8/site-packages/dask/dataframe/groupby.py:369, in numeric_only_not_implemented.<locals>.wrapper(self, *args, **kwargs)
    359             if (
    360                 PANDAS_GT_150
    361                 and not PANDAS_GT_200
    362                 and numeric_only is no_default
    363             ):
    364                 warnings.warn(
    365                     "The default value of numeric_only will be changed to False "
    366                     "in the future when using dask with pandas 2.0",
    367                     FutureWarning,
    368                 )
--> 369 return func(self, *args, **kwargs)

File /opt/miniconda/lib/python3.8/site-packages/dask/dataframe/groupby.py:2832, in DataFrameGroupBy.agg(self, arg, split_every, split_out, shuffle, **kwargs)
   2829 @_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.agg")
   2830 @numeric_only_not_implemented
   2831 def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
-> 2832     return self.aggregate(
   2833         arg=arg,
   2834         split_every=split_every,
   2835         split_out=split_out,
   2836         shuffle=shuffle,
   2837         **kwargs,
   2838     )

File /opt/miniconda/lib/python3.8/site-packages/dask/dataframe/groupby.py:2821, in DataFrameGroupBy.aggregate(self, arg, split_every, split_out, shuffle, **kwargs)
   2818 if arg == "size":
   2819     return self.size()
-> 2821 return super().aggregate(
   2822     arg=arg,
   2823     split_every=split_every,
   2824     split_out=split_out,
   2825     shuffle=shuffle,
   2826     **kwargs,
   2827 )

File /opt/miniconda/lib/python3.8/site-packages/dask/dataframe/groupby.py:2248, in _GroupBy.aggregate(self, arg, split_every, split_out, shuffle, **kwargs)
   2245 else:
   2246     raise ValueError(f"aggregate on unknown object {self.obj}")
-> 2248 chunk_funcs, aggregate_funcs, finalizers = _build_agg_args(spec)
   2250 if isinstance(self.by, (tuple, list)) and len(self.by) > 1:
   2251     levels = list(range(len(self.by)))

File /opt/miniconda/lib/python3.8/site-packages/dask/dataframe/groupby.py:951, in _build_agg_args(spec)
    948 if not isinstance(func, Aggregation):
    949     func = funcname(known_np_funcs.get(func, func))
--> 951 impls = _build_agg_args_single(
    952     result_column, func, func_args, func_kwargs, input_column
    953 )
    955 # overwrite existing result-columns, generate intermediates only once
    956 for spec in impls["chunk_funcs"]:

File /opt/miniconda/lib/python3.8/site-packages/dask/dataframe/groupby.py:1010, in _build_agg_args_single(result_column, func, func_args, func_kwargs, input_column)
   1007     return _build_agg_args_custom(result_column, func, input_column)
   1009 else:
-> 1010     raise ValueError(f"unknown aggregate {func}")

ValueError: unknown aggregate lambda

Hi @smunir1994, welcome here!

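Because Dask works on partitioned, possibly distributed data and applies your aggregation with MapReduce-style logic (a `chunk` step on each partition, then an `agg` step that combines the partial results), `agg` does not accept an arbitrary lambda; that is why you get `ValueError: unknown aggregate lambda`. The code to write is therefore a little more complex: you have to define a custom aggregation with `dd.Aggregation`:
https://docs.dask.org/en/stable/generated/dask.dataframe.groupby.Aggregation.html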

Here is some code that works:

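import itertools
import dask.dataframe as dd
custom_dask_agg = dd.Aggregation(
    name="custom_dask_agg",
    # Per partition: each group's unique values, in order of first appearance, as a list
    chunk=lambda s: s.apply(lambda x: list(dict.fromkeys(x))),
    # Combine partitions: flatten the lists and de-duplicate the individual values again
    agg=lambda s: s.apply(lambda x: list(dict.fromkeys(itertools.chain.from_iterable(x)))),
    # Finally, join each group's unique values into one string
    finalize=lambda s: s.apply(", ".join),
)
A_1.groupby("A").agg(custom_dask_agg).compute()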
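Why the combining step needs care: `chunk` only sees one partition at a time, so a value such as "hello" can reach the `agg` step once per partition. De-duplicating already joined strings (which is what `', '.join(s.unique())` does when used as the `agg` step) removes repeats only when two whole partial strings are identical. The plain-Python sketch below mimics the chunk/agg/finalize flow on two hypothetical partitions; the partition contents and the helpers `chunk`, `agg` and `finalize` are illustrative assumptions, not Dask internals.

```python
from collections import defaultdict
from itertools import chain

# Two hypothetical partitions of (group key, value) rows; "hello" for the
# "saad" group appears in both partitions.
partitions = [
    [("saad", "hello"), ("saad", "saad"), ("nimra", "yup")],
    [("saad", "hello"), ("saad", "whatsup?"), ("nimra", "saad")],
]


def chunk(rows):
    """Map step: for one partition, each group's unique values in first-seen order."""
    out = defaultdict(dict)
    for key, value in rows:
        out[key][value] = None
    return {key: list(seen) for key, seen in out.items()}


def agg(chunk_results):
    """Reduce step: merge the per-partition lists, then de-duplicate individual values."""
    merged = defaultdict(list)
    for result in chunk_results:
        for key, values in result.items():
            merged[key].append(values)
    return {key: list(dict.fromkeys(chain.from_iterable(lists))) for key, lists in merged.items()}


def finalize(aggregated, sep=", "):
    """Final step: join each group's unique values into one string, groups sorted by key."""
    return {key: sep.join(values) for key, values in sorted(aggregated.items())}


result = finalize(agg([chunk(p) for p in partitions]))
print(result)  # {'nimra': 'yup, saad', 'saad': 'hello, saad, whatsup?'}

# For contrast: joining to strings per partition and de-duplicating whole
# strings afterwards keeps the repeated "hello".
partial_strings = [{k: ", ".join(v) for k, v in chunk(p).items()} for p in partitions]
naive_saad = ", ".join(dict.fromkeys(r["saad"] for r in partial_strings))
print(naive_saad)  # hello, saad, hello, whatsup?
```

The same split maps onto `dd.Aggregation(name, chunk, agg, finalize)`: keeping lists as the intermediate result and only joining to a string in `finalize` lets the combining step de-duplicate individual values.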