Hi,
I am trying to sum the elements of each list in a list-valued column of a Dask DataFrame.
import pandas as pd
import dask.dataframe as dd
# Create a Dask DataFrame
data = {'A': [[1, 2], [3, 4], [5, 6]]}
df = dd.from_pandas(pd.DataFrame(data), npartitions=1)
# Sum the list entries in column 'A'
df['sum'] = df['A'].apply(sum, meta=('sum', 'int64')) # meta is required for Dask
# Compute the result
result = df.compute()
print(result)
However, I always get the following error message:
TypeError: unsupported operand type(s) for +: 'int' and 'str'
In pandas, the same operation works without any errors:
import pandas as pd
# Create a pandas DataFrame
data = {'A': [[1, 2], [3, 4], [5, 6]]}
df = pd.DataFrame(data)
# Sum the list entries in column 'A'
df['sum'] = df['A'].apply(sum)
Do you have any ideas or suggestions on how to fix this in Dask?
Thanks
You might be interested in akimbo, which is designed exactly for complex data types stored in normal dataframes. The code to get what you want would be:
import akimbo.dask
df.A.ak.sum(axis=-1)
Note that this will not be much faster than map with sum, or a delayed function with iteration/comprehensions, if your data starts off as Python objects. But if your data comes from a source that can directly produce Arrow vectorized data (e.g., Parquet), then there is no conversion, and akimbo will tend to be much faster than pure Python.
The origin of your exception is that the data type is "object" in the original data, which Dask interprets to mean "string" (which is indeed the most common case), so it silently converts your lists to strings. I don't know how to fix that, except to explicitly convert the data to Arrow first (for which I would also use akimbo):
data = {"A": [[1, 2], [3, 4], [5, 6]] * 100000} # bigger data for benchmarking
df = dd.from_pandas(pd.DataFrame(data).ak.unpack(), npartitions=8)
In [49]: %timeit df.map(sum, meta="int64").compute()
370 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [52]: %timeit df.A.ak.sum(axis=-1).compute()
8.27 ms ± 85.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
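To see why the silent string conversion leads to that exact TypeError, here is a minimal pure-Python sketch. It does not use Dask at all; it only assumes that each list ends up as its string representation, which is my illustration of the conversion described above, not Dask's actual internals.

```python
# Stand-in for one cell of column 'A' before and after the assumed conversion.
original = [1, 2]
as_string = str(original)  # "[1, 2]"

# Summing the real list works as expected.
print(sum(original))

# Summing the string iterates over its characters, so the first step
# is 0 + "[", which raises the TypeError from the question.
try:
    sum(as_string)
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

So the apply call itself is fine; the failure comes from what is stored in the column by the time sum runs.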