NOTE: It will be difficult to provide a good reprex for this issue as it requires an API key to an open-source service, as well as knowledge of the STAC ecosystem and the corresponding pystac library. I have done my best to simplify the sample while still retaining key functional elements of the original notebook.
Problem:
I am running into a Dask CancelledError
when running the compute
function on a collection of delayed objects. The function below wraps the process of fetching a DataArray of satellite imagery based on a STAC item ID. When running the code below on the 3 sample item IDs, there is no error. However, when I am running on a random sample of 5,000 out of 500,000 total STAC items, the error is raised fairly quickly.
Is there a way to catch and log the results of computation on the Dask delayed objects when there is an error, to check if an asset being fetched by the pipeline is failing, or if it’s an issue with Dask itself?
My assumption is the issue is within Dask because I took the exact same wrapper function and mapped it using dask.bag
and it did not raise the CancelledError
I was running into using dask.delayed
. Any thoughts or suggestions would be greatly appreciated.
For practical purposes create_landsat_8_dataarray
mentioned in the raised error traceback is interchangeable with fetch_xarray
.
Code:
import dask
import xarray
import pystac
import pystac_client
import stackstac
import planetary_computer
import datetime
# API endpoints for MLHub and Planetary Computer catalogs
MLHUB_API_URL = "https://api.radiant.earth/mlhub/v1"
MSPC_API_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
MLHUB_API_KEY = "77588a47fe96febbbf270d89b47743600cb296290238471dde9d6784815ac398" # this is a temporary key created for this forum, the data and registration is already open to the public
# Create a connection to the MLHub STAC API endpoint
MLHUB_CATALOG = pystac_client.Client.open(
url=MLHUB_API_URL,
parameters={"key": MLHUB_API_KEY}
)
# Create a connection to the Microsoft Planetary Computer API endpoint
MSPC_CATALOG = pystac_client.Client.open(MSPC_API_URL)
# Names of Collections that will be queried against using pystac_client
BIGEARTHNET_SOURCE_COLLECTION = "bigearthnet_v1_source" # sentinel-2 source imagery
BIGEARTHNET_LABEL_COLLECTION = "bigearthnet_v1_labels" # geojson classification labels
PLANETARY_COMPUTER_LANDSAT_8 = "landsat-8-c2-l2" # landsat 8 source imagery on PC
DATE_BUFFER = 60 # date range to add before and after a item datetime
SAMPLE_ITEM_IDS = [
"bigearthnet_v1_source_S2B_MSIL2A_20180511T100029_14_9",
"bigearthnet_v1_source_S2B_MSIL2A_20180421T114349_54_81",
"bigearthnet_v1_source_S2A_MSIL2A_20170613T101031_24_58"
]
def temporal_buffer(item_dt: str, date_delta: int) -> str:
"""Takes a datetime object and returns a buffer around that date
Args:
item_datetime: datetime property from an Item
date_delta: integer for days to add before and after a date
Returns:
a string range representing the full date buffer
"""
delta = td(days=date_delta)
# item_dt = datetime.strptime(item_datetime, "%Y-%m-%dT%H:%M:%SZ")
dt_start = item_dt - delta
dt_start_str = dt_start.strftime("%Y-%m-%d")
dt_end = item_dt + delta
dt_end_str = dt_end.strftime("%Y-%m-%d")
return f"{dt_start_str}/{dt_end_str}"
def fetch_xarray(item_id: str) -> xarray.DataArray:
"""Takes the sring STAC Item ID to return a DataArray of Landsat 8 images
Args:
item_id: string STAC Item ID for a source Sentinel-2 image from BigEarthNet dataset
Returns:
DataArray of Landsat 8 RGB bands queried from MLHub and Planetary Computer APIs
"""
# fetch Sentinel-2 item based on item ID
s2_item = MLHUB_CATALOG.search(
collections=[BIGEARTHNET_SOURCE_COLLECTION],
ids=[item_id]
).get_all_items()[0]
# fetch Landsat 8 item based on source item metadata
l8_item = MSPC_CATALOG.search(
collections=PLANETARY_COMPUTER_LANDSAT_8,
bbox=s2_item.bbox,
datetime=temporal_buffer(s2_item.datetime, DATE_BUFFER),
).get_all_items()[0]
# fetch landsat 8 RGB bands based on S2 footprint
l8_stack = stackstac.stack(
items=pystac.ItemCollection([planetary_computer.sign(l8_item)]),
assets=["SR_B4", "SR_B3", "SR_B2"],
bounds_latlon=s2_item.bbox,
resolution=10,
)
return l8_stack
### iterative method to benchmark against Dask method
# result_stack = []
# for item_id in SAMPLE_ITEM_IDS:
# result_stack.append(fetch_xarray(item_id))
# result_stack
### Dask method using delayed futures
task_pool = []
for item_id in SAMPLE_ITEM_IDS:
delayed_task = dask.delayed(fetch_xarray)(item_id)
task_pool.append(delayed_task)
task_pool = dask.persist(*task_pool)
computed_result = dask.compute(*task_pool) # this is where the CancelledError occurs when running on 5000+ items
Full Traceback:
---------------------------------------------------------------------------
CancelledError Traceback (most recent call last)
File <timed exec>:2, in <module>
File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/base.py:288, in DaskMethodsMixin.compute(self, **kwargs)
264 def compute(self, **kwargs):
265 """Compute this dask collection
266
267 This turns a lazy Dask collection into its in-memory equivalent.
(...)
286 dask.base.compute
287 """
--> 288 (result,) = compute(self, traverse=False, **kwargs)
289 return result
File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/base.py:571, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
568 keys.append(x.__dask_keys__())
569 postcomputes.append(x.__dask_postcompute__())
--> 571 results = schedule(dsk, keys, **kwargs)
572 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/distributed/client.py:2671, in Client.get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2615 def get(
2616 self,
2617 dsk,
(...)
2629 **kwargs,
2630 ):
2631 """Compute dask graph
2632
2633 Parameters
(...)
2669 Client.compute : Compute asynchronous collections
2670 """
-> 2671 futures = self._graph_to_futures(
2672 dsk,
2673 keys=set(flatten([keys])),
2674 workers=workers,
2675 allow_other_workers=allow_other_workers,
2676 resources=resources,
2677 fifo_timeout=fifo_timeout,
2678 retries=retries,
2679 user_priority=priority,
2680 actors=actors,
2681 )
2682 packed = pack_data(keys, futures)
2683 if sync:
File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/distributed/client.py:2596, in Client._graph_to_futures(self, dsk, keys, workers, allow_other_workers, priority, user_priority, resources, retries, fifo_timeout, actors)
2594 # Pack the high level graph before sending it to the scheduler
2595 keyset = set(keys)
-> 2596 dsk = dsk.__dask_distributed_pack__(self, keyset, annotations)
2598 # Create futures before sending graph (helps avoid contention)
2599 futures = {key: Future(key, self, inform=False) for key in keyset}
File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/highlevelgraph.py:1076, in HighLevelGraph.__dask_distributed_pack__(self, client, client_keys, annotations)
1070 layers = []
1071 for layer in (self.layers[name] for name in self._toposort_layers()):
1072 layers.append(
1073 {
1074 "__module__": layer.__module__,
1075 "__name__": type(layer).__name__,
-> 1076 "state": layer.__dask_distributed_pack__(
1077 self.get_all_external_keys(),
1078 self.key_dependencies,
1079 client,
1080 client_keys,
1081 ),
1082 "annotations": layer.__dask_distributed_annotations_pack__(
1083 annotations
1084 ),
1085 }
1086 )
1087 return {"layers": layers}
File ~/opt/anaconda3/envs/mlhub/lib/python3.9/site-packages/dask/highlevelgraph.py:401, in Layer.__dask_distributed_pack__(self, all_hlg_keys, known_key_dependencies, client, client_keys)
397 raise ValueError(
398 "Inputs contain futures that were created by another client."
399 )
400 if stringify(future.key) not in client.futures:
--> 401 raise CancelledError(stringify(future.key))
403 # Calculate dependencies without re-calculating already known dependencies
404 # - Start with known dependencies
405 dependencies = known_key_dependencies.copy()
CancelledError: create_landsat_8_dataarray-99097cdb-3ce8-4543-a1de-47ff6bf8cdc5