FileNotFoundError when reading input in distributed cluster

Hello, is there an example showing how to correctly set up distributed XGBoost training on a Dask cluster?

I’m using AWS SageMaker training. I start a scheduler on one of the EC2 instances and then start workers (dask-cuda-worker) on all instances. Then I create a client that connects to the scheduler and reads the parquet input before starting training. The input parquet is in S3, about 48 GB in total, with each part around 46 MB. I’m using SageMaker’s `distribution="ShardedByS3Key"` feature, which splits the parquet part files between the instances. However, I get errors like the following:

```
2022-12-28 18:23:46,163 - distributed.worker - WARNING - Compute Failed
Key:       ('read-parquet-4c1cfa0cfd4211f054320374366c7a88', 95)
Function:  subgraph_callable-f8208900-18f8-4845-b8e5-e73533ce
args:      ({'piece': ('/opt/ml/input/data/train/part.382.parquet', None, None)})
kwargs:    {}
Exception: "FileNotFoundError(2, 'No such file or directory')"
```

It looks like the worker tries to load files that aren’t present on its EC2 instance. When I print `client.has_what()`, I see the following:

```
{'tcp://10.0.126.177:38723': ("('getitem-bdead9887a4155209c76c6f5db0c8dfa', 0)",
                              "('getitem-cb025737c36a128e1ed35fbea35e9f9d', 0)",
                              …,
                              "('getitem-1f0bac67347de86fc1ac691363232706', 203)",
                              "('getitem-540d5c7fa3c021728445e4efed1d893a', 203)"),
 'tcp://10.0.140.140:40423': (),
 'tcp://10.0.154.212:44707': (),
 'tcp://10.0.169.177:42301': ()}
```
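A quick way to read output like this is to count the keys each worker holds. The sketch below uses two hypothetical helpers (not part of Dask) that take the dict returned by `Client.has_what()`, which maps each worker address to the keys held in that worker's memory. `sample` is an abbreviated copy of the output above, keeping only two of the keys on the first worker.

```python
# Abbreviated copy of the client.has_what() output shown above.
sample = {
    "tcp://10.0.126.177:38723": (
        "('getitem-bdead9887a4155209c76c6f5db0c8dfa', 0)",
        "('getitem-cb025737c36a128e1ed35fbea35e9f9d', 0)",
    ),
    "tcp://10.0.140.140:40423": (),
    "tcp://10.0.154.212:44707": (),
    "tcp://10.0.169.177:42301": (),
}


def tasks_per_worker(has_what):
    """Map each worker address to the number of keys it holds in memory."""
    return {worker: len(keys) for worker, keys in has_what.items()}


def idle_workers(has_what):
    """Return the sorted addresses of workers that hold no keys at all."""
    return sorted(worker for worker, keys in has_what.items() if len(keys) == 0)


print(tasks_per_worker(sample))
print(idle_workers(sample))
# On the live cluster: print(idle_workers(client.has_what()))
```

Here three of the four workers come back idle, which matches the observation that only one worker ends up with data.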

I have 4 workers (on 4 EC2 instances), but only tcp://10.0.126.177:38723 holds any data, so clearly some workers aren’t getting any tasks. I also tried `distribution="FullyReplicated"` so that all the parts are present on every instance, but then I get out-of-memory errors, even when using many GPUs. I think my code isn’t using Dask to read only part of the data on each worker.
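One way to check what each worker can actually see is to run a small listing function on every worker with `Client.run`, which executes a function on each worker and returns a dict keyed by worker address. This is a sketch that assumes the SageMaker channel is mounted at `/opt/ml/input/data/train` (the path in the log above); `list_local_parts` and `missing_files` are hypothetical helpers, not part of Dask.

```python
import glob
import os

# Assumed location of the SageMaker "train" channel (taken from the log above).
TRAIN_DIR = "/opt/ml/input/data/train"


def list_local_parts(directory=TRAIN_DIR):
    """Return the parquet part files present on the machine this runs on."""
    return sorted(glob.glob(os.path.join(directory, "*.parquet")))


def missing_files(files_per_worker):
    """For each worker, list the files some other worker has but it does not.

    files_per_worker maps worker address -> list of local paths, which is
    the shape Client.run returns when it runs list_local_parts everywhere.
    """
    all_files = set()
    for paths in files_per_worker.values():
        all_files.update(paths)
    return {
        worker: sorted(all_files - set(paths))
        for worker, paths in files_per_worker.items()
    }


# Usage against the running cluster (not executed here):
#   files_per_worker = client.run(list_local_parts)
#   print(missing_files(files_per_worker))
```

With `ShardedByS3Key` each worker should report a different subset and a non-empty "missing" list; with `FullyReplicated` every "missing" list should be empty.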

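```python
import os
from subprocess import Popen

# DASK_PATH and SCHEDULER_PORT are defined elsewhere in the training script.


def start_daemons(master_ip, current_host_ip):
    cmd_start_scheduler = os.path.join(DASK_PATH, "dask-scheduler")
    cmd_start_worker = os.path.join(DASK_PATH, "dask-cuda-worker")
    schedule_conn_string = "tcp://{ip}:{port}".format(ip=master_ip, port=SCHEDULER_PORT)
    if master_ip == current_host_ip:
        # The master instance also runs the scheduler, on the port the workers use.
        Popen([cmd_start_scheduler, "--port", str(SCHEDULER_PORT)])
    # Every instance, including the master, runs dask-cuda-worker.
    Popen([cmd_start_worker, "--no-dashboard", schedule_conn_string])
```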

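```python
import dask_cudf
from dask.distributed import wait


def load_train_and_validation_parquet(train_path, valid_path):
    # Lazily build GPU DataFrames from the parquet part files.
    train_df = dask_cudf.read_parquet(train_path)
    valid_df = dask_cudf.read_parquet(valid_path)

    # Split features and label.
    y_train = train_df["label"]
    X_train = train_df[train_df.columns.difference(["label"])]

    y_valid = valid_df["label"]
    X_valid = valid_df[valid_df.columns.difference(["label"])]

    # Load every partition into worker memory and block until done.
    # `client` is the module-level Client created in the `with` block below.
    X_train, X_valid, y_train, y_valid = client.persist(
        [X_train, X_valid, y_train, y_valid]
    )
    wait([X_train, X_valid, y_train, y_valid])

    # Show which keys ended up on which worker.
    print(client.has_what())

    return X_train, X_valid, y_train, y_valid
```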

...

```python
import xgboost as xgb
from dask.distributed import Client

# scheduler_addr, args, train_hp, current_host_ip and master_host_ip
# are defined in the elided part of the script.
with Client(scheduler_addr) as client:
    X_train, X_valid, y_train, y_valid = load_train_and_validation_parquet(
        args.train, args.validation
    )

    # Build distributed DMatrix objects from the persisted collections.
    dtrain = xgb.dask.DaskDMatrix(client, X_train, y_train)
    dvalid = xgb.dask.DaskDMatrix(client, X_valid, y_valid)
    output = xgb.dask.train(
        client,
        train_hp,
        dtrain,
        num_boost_round=args.num_round,
        evals=[(dvalid, "Valid")],
    )

    # Only the master instance saves the model.
    if current_host_ip == master_host_ip:
        booster = output["booster"]  # booster is the trained model
        history = output["history"]  # A dictionary containing evaluation results
        booster.save_model(args.model_dir + "/xgboost-model")
        print("Training evaluation history:", history)
```

Thanks in advance!

Hi @cd2202,

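I don’t know how SageMaker works, but my guess is that Dask isn’t aware of how SageMaker shards the S3 inputs across the instances, so it doesn’t take that into account: it can send the read task for a given part file to a worker on an instance that doesn’t have that file. That would explain the error you’re seeing.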
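If you want to keep the sharded input, one possible workaround (a sketch, untested on SageMaker) is to list each instance's files on that instance's workers and pin each read to a worker that actually has the file. `assign_files_to_workers` and `read_local_shards` are hypothetical helpers; `Client.run`, `Client.submit(..., workers=...)` and `dask.dataframe.from_delayed` are real Dask APIs, but check their behavior against your Dask version.

```python
def assign_files_to_workers(files_per_worker):
    """Give every part file to exactly one worker that has it on local disk.

    files_per_worker maps worker address -> list of local paths (for example
    the result of client.run(...) with a directory-listing function).
    Files present on several workers (FullyReplicated, or several GPU
    workers on one instance) go to the least-loaded of those workers.
    """
    holders = {}
    for worker, paths in files_per_worker.items():
        for path in paths:
            holders.setdefault(path, []).append(worker)

    load = {worker: 0 for worker in files_per_worker}
    plan = {worker: [] for worker in files_per_worker}
    for path in sorted(holders):
        # Ties are broken by worker address so the plan is deterministic.
        worker = min(sorted(holders[path]), key=lambda w: load[w])
        plan[worker].append(path)
        load[worker] += 1
    return plan


def read_local_shards(client, plan):
    """Hypothetical wiring: read each file on the worker that owns it.

    Needs cudf and dask installed; not exercised outside a cluster.
    """
    import cudf
    import dask.dataframe as dd

    futures = [
        # workers=[...] restricts the task to that worker; allow_other_workers
        # defaults to False, so it is not sent to an instance lacking the file.
        client.submit(cudf.read_parquet, path, workers=[worker], pure=False)
        for worker, paths in plan.items()
        for path in paths
    ]
    return dd.from_delayed(futures)


# Usage on the cluster (not executed here):
#   import glob
#   files = client.run(lambda: sorted(glob.glob("/opt/ml/input/data/train/*.parquet")))
#   train_df = read_local_shards(client, assign_files_to_workers(files))
```

The trade-off of pinning is that a partition can only be recomputed on the worker that holds its file, so losing that worker loses the partition. Another option is to point `dask_cudf.read_parquet` at the `s3://` prefix instead of the local channel, so any worker can fetch any part; that requires s3fs and S3 credentials on every instance.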

In this code, it looks like you’re trying to persist the entire dataset into memory. Do you have enough memory for this?
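To put rough numbers on the memory question: persisting keeps every partition of `X_train`, `y_train`, `X_valid` and `y_valid` in worker memory at the same time. The helper below is a back-of-the-envelope sketch; `expansion` (in-memory size divided by on-disk parquet size) is an assumption you would need to measure on a few parts, since parquet is compressed.

```python
def persisted_bytes_per_worker(dataset_bytes, n_workers, expansion=1.0):
    """Rough per-worker memory needed to persist a dataset spread evenly.

    dataset_bytes: on-disk size of the parquet input.
    n_workers: number of Dask workers sharing the data.
    expansion: ASSUMED ratio of in-memory to on-disk size (often above 1
        because parquet is compressed; measure it for your data).
    Assumes an even spread, which the has_what() output in the question
    shows was not happening: one worker held everything.
    """
    if n_workers <= 0:
        raise ValueError("n_workers must be positive")
    return dataset_bytes * expansion / n_workers


print(persisted_bytes_per_worker(48e9, 4) / 1e9)                 # 12.0 (GB)
print(persisted_bytes_per_worker(48e9, 4, expansion=3.0) / 1e9)  # 36.0 (GB)
```

Even in the best case of an even spread and no expansion, each of the 4 workers would need about 12 GB for the persisted data alone, before XGBoost builds its own structures from it.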