Spatial join large dask dataframe with small dataframe

What is the best way to perform this merge? My Dask dataframe is a large Twitter dataset with ~20,000 partitions. The pandas dataframe with FIPS codes is 100 KB. Up until I call the "add_fips" method, the pipeline runs smoothly and quickly (in just a few seconds). However, when I call "add_fips", I get a message that says:

UserWarning: Sending large graph of size 335.67 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.

and the pipeline never seems to terminate. What is the correct way to do this merge? Here is my code:

import glob

import pandas as pd
import geopandas as gpd
import dask.dataframe as dd
from dask import compute, delayed
from dask.distributed import Client, LocalCluster


def main():
    cluster = LocalCluster(n_workers=64, threads_per_worker=1, memory_limit='4 GiB')
    client = Client(cluster)
    
    gdf_counties = build_counties(INDIR_GEOGRAPHY_COUNTIES)
    gdf_places = build_places(INDIR_GEOGRAPHY_PLACES)
    
    tweets, status_log = build_tweets(INDIR_TWEETS, YEARS, SCHEMA)
    tweets_clean = tweets.map_partitions(clean_tweets)
    tweets_with_fips = tweets_clean.map_partitions(
        add_fips,
        gdf_counties,
        gdf_places,
        meta=tweets_clean._meta.assign(fips_county=0.0, fips_place=0.0),
    )
    tweets_with_fips.count().compute()
    

def build_tweets(INDIR, YEARS, SCHEMA):
    files = [f for year in YEARS for f in glob.glob(str(INDIR / f"{year}" / "*.csv.gz"))]
    delayed_results = [read_file(file, SCHEMA) for file in files]
    dfs = [result[0] for result in delayed_results]
    logs = [result[1] for result in delayed_results]
    tweets = dd.from_delayed(dfs, meta=SCHEMA)
    log = pd.DataFrame(compute(*logs))
    return tweets, log

@delayed
def read_file(filename, SCHEMA):
    is_fail = 0
    error_msg = None
    try:
        dtypes = {k: v for k, v in SCHEMA.items() if k != 'date'}
        df = pd.read_csv(
            filename,
            engine="c",
            compression="gzip",
            sep="\t",
            low_memory=True,
            quotechar='"',
            lineterminator='\n',
            on_bad_lines='warn',
            parse_dates=['date'],
            date_format="%Y-%m-%d %H:%M:%S",
            dtype=dtypes
        )
    except Exception as e:
        df = pd.DataFrame({col: pd.Series(dtype=dtype) for col, dtype in SCHEMA.items()})
        is_fail = 1
        error_msg = str(e)
    log = {"file": filename, "failure": is_fail, "error": error_msg}
    return df, log


def build_counties(INDIR_COUNTIES):
    counties = gpd.read_file(INDIR_COUNTIES / "counties_2024.shp")
    
    if counties is None:
        return None
    
    counties_clean = (
        counties
        .assign(fips_state = lambda x: x["STATEFP"].astype(float))
        .assign(fips_county = lambda x: x["GEOID"].astype(float))
        .query("fips_state <= 56")
        [['geometry', 'fips_county']]
        .to_crs("EPSG:4326")
    )

    return counties_clean

def build_places(INDIR_PLACES):
    shapefiles = INDIR_PLACES.rglob("*.shp")
    gdfs = []
    for shp in shapefiles:
        places_shapefile = gpd.read_file(shp).to_crs("EPSG:4326")
        gdfs.append(places_shapefile)
    
    places = pd.concat(gdfs, ignore_index=True)
    places_clean = (
        places
        .assign(fips_state = lambda x: x["STATEFP"].astype(float))
        .assign(fips_place = lambda x: x["GEOID"].astype(float))
        .query("fips_state <= 56")
        [['geometry', 'fips_place']]
    )
    return places_clean


def add_fips(df_tweets, gdf_counties, gdf_places):
    gdf_tweets = gpd.GeoDataFrame(df_tweets, geometry=gpd.points_from_xy(df_tweets['longitude'], df_tweets['latitude']), crs='epsg:4326')
    gdf_tweets_merged = (
        gdf_tweets
        .sjoin(gdf_counties, how='inner', predicate='within')
        .drop(columns=['index_right'])
        .sjoin(gdf_places, how='left', predicate='within')
        .drop(columns=['index_right'])
    )
    return gdf_tweets_merged

Do you mean that if you comment out the tweets_clean.map_partitions(add_fips, ...) line, the code runs fine?

Maybe you should try the suggestion from the warning (embed the data with delayed objects) and add something like:

gdf_counties_delayed = delayed(gdf_counties)
gdf_places_delayed = delayed(gdf_places)

...
tweets_with_fips = tweets_clean.map_partitions(
    add_fips,
    gdf_counties_delayed,
    gdf_places_delayed,
    meta=tweets_clean._meta.assign(fips_county=0.0, fips_place=0.0),
)
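
Wrapping the two GeoDataFrames in delayed embeds each of them in the graph once, instead of serializing them into every one of the ~20,000 map_partitions tasks, which is what triggers the large-graph warning.

The warning also mentions futures. If you prefer that route, a minimal sketch (assuming the client created in main() is in scope; the *_future names are just illustrative) is to scatter the two small frames to the workers once and pass the resulting futures:

# Ship the ~100 KB lookup frames to every worker up front, then reference them by future.
gdf_counties_future = client.scatter(gdf_counties, broadcast=True)
gdf_places_future = client.scatter(gdf_places, broadcast=True)

tweets_with_fips = tweets_clean.map_partitions(
    add_fips,
    gdf_counties_future,
    gdf_places_future,
    meta=tweets_clean._meta.assign(fips_county=0.0, fips_place=0.0),
)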