From 44a581b57ff6b4053b7cd3e793029d5eda3e1b79 Mon Sep 17 00:00:00 2001 From: Shane A Date: Fri, 17 May 2024 13:58:40 -0700 Subject: [PATCH] Allow hybrid sharding to have multiple replicas in a node --- scripts/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 2f4235482..5d8106a75 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -160,9 +160,8 @@ def dummy_init_fn(module: torch.nn.Module) -> None: if num_model_replicas <= 0: raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer") - num_nodes = get_world_size() // get_local_world_size() - if num_nodes > 1 and num_nodes % num_model_replicas != 0: - raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes") + if get_world_size() % num_model_replicas != 0: + raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide world size") device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas)) hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh