Kernel Crashes on .compute()

I am running a BERT model (413MB) on a text file (440MB) read into a dask bag. The detect_task function runs the model on each item in the bag. However, the kernel crashes when .compute() is called. Any help is greatly appreciated.

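import torch
import distributed
import dask.bag as db
from transformers import BertTokenizer

# `gateway` (presumably a dask-gateway Gateway instance) and `keys` are defined elsewhere in the notebook.
cluster = gateway.connect('daskhub.3e427e2debd2421595a53fee61632972')
client = distributed.Client(cluster)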

model = torch.load("dangermodel", map_location=torch.device("cpu"))
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

storage_options = {'account_name': keys['account_name'], 'account_key':keys['account_key']}
breviews = db.read_text('abfs://negative-reviews/neg_reviews_2.txt', storage_options = storage_options, blocksize = '5MB')

def detect_task(review):
    review = review.strip('\n')
    rev_encoded = tokenizer(review, return_tensors='pt', padding=True, truncation=True)
    outputs = model(rev_encoded['input_ids'])
    danger_level = float(outputs[0][0][1])
    return (review, danger_level)

reviews_bag =
results = reviews_bag.compute()

Hi @vnguye65 thanks for the question and welcome to discourse!

When you say the kernel crashes, do you mean that the local kernel (i.e. the one behind your Jupyter notebook) crashes, or that a worker crashes? If you’re referring to the local kernel, this could mean that results is too large: there is too much data to send back to the client, which causes the crash. Are there any warning messages or a traceback you can share? Or perhaps a minimal reproducible example?

Hi @scharlottej13,
Thank you for your response. The kernel crashes as soon as .compute() is run, and there don’t seem to be any tasks running on the dashboard. The screenshot below shows the error message “Killed” in IPython.

Hi @vnguye65, thanks for providing more information! It does seem like this is a memory error, but I have a couple of questions to confirm:

  1. If instead of reviews_bag.compute() you use reviews_bag.take(5) (this is similar to a df.head() in pandas, see these docs), do you still see the Killed error?
  2. You noted breviews is 440MB. Is this the size on disk? If so, what is the size in memory? How many items are in breviews?
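
The two checks above could be sketched roughly as follows. This is only a minimal sketch: a tiny in-memory bag stands in for breviews, the sample strings are made up, and sys.getsizeof is just one rough way to estimate in-memory size. The synchronous scheduler is set so the snippet runs locally; with a distributed Client active, the client would be used instead.

```python
import sys

import dask
import dask.bag as db

# Local sketch only: run everything in the current process.
dask.config.set(scheduler="synchronous")

# Stand-in for `breviews`; the real bag comes from db.read_text(...).
breviews = db.from_sequence(
    ["great product\n", "broke after a week\n", "would not buy again\n"],
    npartitions=2,
)

# take() only computes the first partition by default, so it is a cheap
# way to check that the pipeline works at all.
first_items = breviews.take(2)

# Count the items and estimate their in-memory size on the workers,
# returning only two small numbers to the local process.
n_items, approx_bytes = dask.compute(
    breviews.count(),
    breviews.map(sys.getsizeof).sum(),
)

print(breviews.npartitions, n_items, approx_bytes)
```

If take() succeeds but compute() is killed, that points toward the size of the collected results (or of what is sent to the cluster) rather than toward detect_task itself.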

If it is indeed a memory issue, then you may have to split results and write them to disk rather than loading everything at once.

I’m able to get the first 5 items in reviews_bag using reviews_bag.take(5).
Could you please expand on what you mean by split and write?
Thank you

If we’re assuming this is a memory problem, then you won’t be able to load results in memory all at once. It seems this is the case, especially since you can load the first 5 elements of reviews_bag.

The way this is usually handled is by splitting up results and writing individual files. In the docs I linked there are examples of how to do this with Dask bag, or you can convert the bag to a Dask DataFrame. Given the snippet you shared, you could start by writing each result to disk from within detect_task, and then adjust from there based on how much memory you have available.
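
As a rough illustration of the "split and write" idea, the sketch below writes one file per partition with Bag.to_textfiles instead of collecting everything with .compute(). It is a sketch under assumptions: detect_task is replaced by a hypothetical stand-in that does not call the model, the input strings are made up, and the output goes to a local temporary directory. On a remote cluster the path would need to point at storage every worker can reach.

```python
import json
import os
import tempfile

import dask
import dask.bag as db

# Local sketch only: run everything in the current process.
dask.config.set(scheduler="synchronous")


def detect_task(review):
    # Hypothetical stand-in for the BERT call in the original post.
    review = review.strip("\n")
    danger_level = float(len(review))
    return (review, danger_level)


def to_json_line(pair):
    # Serialize one (review, danger_level) result as a JSON string.
    return json.dumps({"review": pair[0], "danger_level": pair[1]})


reviews_bag = db.from_sequence(
    ["good\n", "bad\n", "awful\n", "fine\n"], npartitions=2
).map(detect_task)

out_dir = tempfile.mkdtemp()

# The "*" is replaced by the partition index, so each partition is written
# to its own file and no result is gathered in the local process.
reviews_bag.map(to_json_line).to_textfiles(os.path.join(out_dir, "results-*.jsonl"))

files = sorted(os.listdir(out_dir))
print(files)
```

The results can later be read back lazily with db.read_text on the same glob pattern and parsed with json.loads.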

Hi @scharlottej13, thank you for the suggestion. Another issue is that the only way I have found to avoid the kernel crashing is to use a single partition, which means only one worker performs all the calculations. When the dask bag is split into more than one partition, the kernel crashes immediately, seemingly because a model of that size cannot be loaded onto several workers all at once.
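
One possible explanation, which nothing in this thread confirms, is that the model is captured as a global by detect_task and gets serialized along with the tasks. A pattern sometimes used in that situation is to load the model inside a function passed to map_partitions, so that it is loaded on the worker once per partition. The sketch below is only an assumption-laden illustration: load_model is a hypothetical stand-in for torch.load("dangermodel", ...), and the scoring is fake.

```python
import dask
import dask.bag as db

# Local sketch only: run everything in the current process.
dask.config.set(scheduler="synchronous")


def load_model():
    # Hypothetical stand-in for torch.load("dangermodel", ...);
    # it returns a callable that scores a piece of text.
    return lambda text: float(len(text))


def detect_partition(reviews):
    # Called once per partition on whichever worker holds it, so the
    # model is loaded there rather than shipped from the client.
    model = load_model()
    results = []
    for review in reviews:
        review = review.strip("\n")
        results.append((review, model(review)))
    return results


breviews = db.from_sequence(["good\n", "bad\n", "awful\n"], npartitions=2)
results = breviews.map_partitions(detect_partition).compute()
print(results)
```

With larger partitions the load cost is paid less often, so the blocksize passed to read_text becomes a trade-off between parallelism and repeated model loading.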

@vnguye65 that certainly sounds challenging! It would really help if you could share a minimal reproducible example that is copy-pasteable and run using a local cluster. The model and data in your example can be a simplification of your actual model and data. Additionally, knowing more about your cluster setup and in particular the available resources (e.g. number of workers, memory available) would really help.
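
For reference, a minimal reproducible example along these lines might look like the sketch below. Everything in it is an assumption chosen for illustration: the worker count, the memory limit, the generated reviews, and a detect_task that skips the model entirely. The cluster runs its workers as threads in the current process (processes=False) to keep the snippet small.

```python
import dask.bag as db
from distributed import Client, LocalCluster


def detect_task(review):
    # Simplified stand-in for the real model call.
    review = review.strip("\n")
    return (review, float(len(review)))


# Small local cluster with explicit, easy-to-report resources.
cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
    processes=False,
    memory_limit="1GiB",
    dashboard_address=None,  # disable the dashboard for this sketch
)
client = Client(cluster)

# Synthetic data in place of the 440MB text file.
breviews = db.from_sequence([f"review {i}\n" for i in range(20)], npartitions=4)

# With a Client active, compute() runs on the local cluster.
results = breviews.map(detect_task).compute()
print(len(results), results[:2])

client.close()
cluster.close()
```

Scaling the synthetic data or swapping in a small dummy model until the crash reappears would help narrow down which part is responsible.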

I ran into something very similar, and setting the scheduler to processes on the client resolved it. I still haven’t figured out the root cause of the memory leak with the multiprocessing scheduler.