Skip to content

Commit

Permalink
Merge pull request #582 from allenai/shanea/hybrid-shard-as-no-shard
Browse files: browse the repository at this point in the history
Allow hybrid sharding to have multiple replicas in a node
  • Loading branch information
2015aroras authored May 17, 2024
2 parents c29787a + f05bf9b commit e6430a0
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions scripts/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -160,9 +160,8 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
 if num_model_replicas <= 0:
     raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")

-num_nodes = get_world_size() // get_local_world_size()
-if num_nodes > 1 and num_nodes % num_model_replicas != 0:
-    raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")
+if get_world_size() % num_model_replicas != 0:
+    raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide world size")

 device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
 hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh
Expand Down

0 comments on commit e6430a0

Please sign in to comment.