CancelledError when running .compute() on DataArray of 5,000+ objects

NOTE: It will be difficult to provide a good reprex for this issue as it requires an API key to an open-source service, as well as knowledge of the STAC ecosystem and the corresponding pystac library. I have done my best to simplify the sample while still retaining key functional elements of the original notebook.

Problem:

I am running into a Dask CancelledError when calling compute on a collection of delayed objects. The function below wraps the process of fetching a DataArray of satellite imagery based on a STAC item ID. When running the code below on the 3 sample item IDs, there is no error. However, when running on a random sample of 5,000 of the 500,000 total STAC items, the error is raised fairly quickly.

Is there a way to catch and log the results of computing the Dask delayed objects when there is an error, so I can check whether an asset fetched by the pipeline is failing or whether it's an issue with Dask itself?
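
For context, this is roughly the per-item error capture I have in mind. safe_fetch here is a hypothetical wrapper around the fetch_xarray function defined below, not something already in my pipeline:

def safe_fetch(item_id: str):
    """Hypothetical wrapper: return the exception instead of raising it,
    so one bad asset can be logged without taking down the whole compute."""
    try:
        return item_id, fetch_xarray(item_id)
    except Exception as exc:
        return item_id, exc

The computed results could then be filtered for Exception instances to see exactly which item IDs fail.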

My assumption is that the issue is within Dask, because when I mapped the exact same wrapper function over the items using dask.bag, it did not raise the CancelledError I was running into with dask.delayed. Any thoughts or suggestions would be greatly appreciated.
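
For reference, the dask.bag version that completed without the error looked roughly like this (a sketch from memory, using the same fetch_xarray defined below):

import dask.bag as db

# map the same wrapper function over the item IDs with dask.bag instead of dask.delayed
item_bag = db.from_sequence(SAMPLE_ITEM_IDS)
bag_results = item_bag.map(fetch_xarray).compute()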

For practical purposes, create_landsat_8_dataarray in the traceback below is interchangeable with fetch_xarray.

Code:

import dask
import xarray
import pystac
import pystac_client
import stackstac
import planetary_computer
from datetime import datetime, timedelta

# API endpoints for MLHub and Planetary Computer catalogs
MLHUB_API_URL = "https://api.radiant.earth/mlhub/v1"
MSPC_API_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
MLHUB_API_KEY = "77588a47fe96febbbf270d89b47743600cb296290238471dde9d6784815ac398" # temporary key created for this forum; the data and registration are already open to the public

# Create a connection to the MLHub STAC API endpoint
MLHUB_CATALOG = pystac_client.Client.open(
    url=MLHUB_API_URL, 
    parameters={"key": MLHUB_API_KEY}
)

# Create a connection to the Microsoft Planetary Computer API endpoint
MSPC_CATALOG = pystac_client.Client.open(MSPC_API_URL)

# Names of Collections that will be queried against using pystac_client
BIGEARTHNET_SOURCE_COLLECTION = "bigearthnet_v1_source"  # sentinel-2 source imagery
BIGEARTHNET_LABEL_COLLECTION = "bigearthnet_v1_labels"  # geojson classification labels
PLANETARY_COMPUTER_LANDSAT_8 = "landsat-8-c2-l2"  # landsat 8 source imagery on PC
DATE_BUFFER = 60  # days to add before and after an item's datetime

SAMPLE_ITEM_IDS = [
    "bigearthnet_v1_source_S2B_MSIL2A_20180511T100029_14_9",
    "bigearthnet_v1_source_S2B_MSIL2A_20180421T114349_54_81",
    "bigearthnet_v1_source_S2A_MSIL2A_20170613T101031_24_58",
]

def temporal_buffer(item_dt: datetime, date_delta: int) -> str:
    """Takes a datetime object and returns a date range buffered around it

    Args:
        item_dt: datetime property from an Item
        date_delta: integer number of days to add before and after the date

    Returns:
        a string range representing the full date buffer
    """
    delta = timedelta(days=date_delta)

    dt_start = item_dt - delta
    dt_start_str = dt_start.strftime("%Y-%m-%d")

    dt_end = item_dt + delta
    dt_end_str = dt_end.strftime("%Y-%m-%d")

    return f"{dt_start_str}/{dt_end_str}"
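# e.g. temporal_buffer(datetime(2018, 5, 11), 60) -> "2018-03-12/2018-07-10"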

def fetch_xarray(item_id: str) -> xarray.DataArray:
    """Takes the sring STAC Item ID to return a DataArray of Landsat 8 images
    
    Args:
        item_id: string STAC Item ID for a source Sentinel-2 image from BigEarthNet dataset
        
    Returns:
        DataArray of Landsat 8 RGB bands queried from MLHub and Planetary Computer APIs
    """
    # fetch Sentinel-2 item based on item ID
    s2_item = MLHUB_CATALOG.search(
        collections=[BIGEARTHNET_SOURCE_COLLECTION],
        ids=[item_id]
    ).get_all_items()[0]
    
    # fetch Landsat 8 item based on source item metadata
    l8_item = MSPC_CATALOG.search(
        collections=[PLANETARY_COMPUTER_LANDSAT_8],
        bbox=s2_item.bbox,
        datetime=temporal_buffer(s2_item.datetime, DATE_BUFFER),
    ).get_all_items()[0]
    
    # fetch Landsat 8 RGB bands based on the Sentinel-2 footprint
    l8_stack = stackstac.stack(
        items=pystac.ItemCollection([planetary_computer.sign(l8_item)]),
        assets=["SR_B4", "SR_B3", "SR_B2"],
        bounds_latlon=s2_item.bbox,
        resolution=10,
    )

    return l8_stack
  
### iterative method to benchmark against Dask method
# result_stack = []
# for item_id in SAMPLE_ITEM_IDS:
#     result_stack.append(fetch_xarray(item_id))
# result_stack

### Dask method using delayed objects
task_pool = []

for item_id in SAMPLE_ITEM_IDS:
    delayed_task = dask.delayed(fetch_xarray)(item_id)
    task_pool.append(delayed_task)
    
task_pool = dask.persist(*task_pool)
computed_result = dask.compute(*task_pool) # this is where the CancelledError occurs when running on 5000+ items
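
If it helps, this is the kind of per-task logging I was hoping to get, using the distributed futures API instead of a plain dask.compute (a sketch, assuming an already-running distributed.Client named client):

from distributed import as_completed

# submit the tasks and inspect each future as it finishes,
# logging failed tasks instead of raising on the first error
futures = client.compute(task_pool)
for future in as_completed(futures):
    if future.status == "error":
        print(f"task {future.key} failed: {future.exception()}")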

Full Traceback:

---------------------------------------------------------------------------
CancelledError                            Traceback (most recent call last)
File <timed exec>:2, in <module>

File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/base.py:288, in DaskMethodsMixin.compute(self, **kwargs)
    264 def compute(self, **kwargs):
    265     """Compute this dask collection
    266 
    267     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    286     dask.base.compute
    287     """
--> 288     (result,) = compute(self, traverse=False, **kwargs)
    289     return result

File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/base.py:571, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    568     keys.append(x.__dask_keys__())
    569     postcomputes.append(x.__dask_postcompute__())
--> 571 results = schedule(dsk, keys, **kwargs)
    572 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/distributed/client.py:2671, in Client.get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2615 def get(
   2616     self,
   2617     dsk,
   (...)
   2629     **kwargs,
   2630 ):
   2631     """Compute dask graph
   2632 
   2633     Parameters
   (...)
   2669     Client.compute : Compute asynchronous collections
   2670     """
-> 2671     futures = self._graph_to_futures(
   2672         dsk,
   2673         keys=set(flatten([keys])),
   2674         workers=workers,
   2675         allow_other_workers=allow_other_workers,
   2676         resources=resources,
   2677         fifo_timeout=fifo_timeout,
   2678         retries=retries,
   2679         user_priority=priority,
   2680         actors=actors,
   2681     )
   2682     packed = pack_data(keys, futures)
   2683     if sync:

File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/distributed/client.py:2596, in Client._graph_to_futures(self, dsk, keys, workers, allow_other_workers, priority, user_priority, resources, retries, fifo_timeout, actors)
   2594 # Pack the high level graph before sending it to the scheduler
   2595 keyset = set(keys)
-> 2596 dsk = dsk.__dask_distributed_pack__(self, keyset, annotations)
   2598 # Create futures before sending graph (helps avoid contention)
   2599 futures = {key: Future(key, self, inform=False) for key in keyset}

File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/highlevelgraph.py:1076, in HighLevelGraph.__dask_distributed_pack__(self, client, client_keys, annotations)
   1070 layers = []
   1071 for layer in (self.layers[name] for name in self._toposort_layers()):
   1072     layers.append(
   1073         {
   1074             "__module__": layer.__module__,
   1075             "__name__": type(layer).__name__,
-> 1076             "state": layer.__dask_distributed_pack__(
   1077                 self.get_all_external_keys(),
   1078                 self.key_dependencies,
   1079                 client,
   1080                 client_keys,
   1081             ),
   1082             "annotations": layer.__dask_distributed_annotations_pack__(
   1083                 annotations
   1084             ),
   1085         }
   1086     )
   1087 return {"layers": layers}

File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/highlevelgraph.py:401, in Layer.__dask_distributed_pack__(self, all_hlg_keys, known_key_dependencies, client, client_keys)
    397         raise ValueError(
    398             "Inputs contain futures that were created by another client."
    399         )
    400     if stringify(future.key) not in client.futures:
--> 401         raise CancelledError(stringify(future.key))
    403 # Calculate dependencies without re-calculating already known dependencies
    404 # - Start with known dependencies
    405 dependencies = known_key_dependencies.copy()

CancelledError: create_landsat_8_dataarray-99097cdb-3ce8-4543-a1de-47ff6bf8cdc5

Hi @KennSmith

Welcome!
Unfortunately, I'm not able to reproduce the errors you're encountering.
Here is a list of my conda environment's packages so you can check whether yours is similar:

> conda list "dask|xarray|planetary-computer|stackstac|pystac|pystac-client"
# packages in environment at C:\Users\runner\anaconda3\envs\daskenv:
#
# Name                    Version                   Build  Channel
dask                      2022.5.0           pyhd8ed1ab_0    conda-forge
dask-core                 2022.5.0           pyhd8ed1ab_0    conda-forge
dask-image                2021.12.0          pyhd8ed1ab_0    conda-forge
planetary-computer        0.4.6              pyhd8ed1ab_0    conda-forge
pystac                    1.2.0              pyhd8ed1ab_0    conda-forge
pystac-client             0.3.2              pyhd8ed1ab_0    conda-forge
stackstac                 0.4.1              pyhd8ed1ab_0    conda-forge
xarray                    2022.3.0           pyhd8ed1ab_0    conda-forge