From c9ceb5c28a438b03ac3a1442138b60a1fe9dd4ae Mon Sep 17 00:00:00 2001
From: Shane A
Date: Tue, 9 Apr 2024 15:06:08 -0700
Subject: [PATCH 1/3] Add config for model replicas in hybrid sharding

---
 olmo/config.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/olmo/config.py b/olmo/config.py
index 042c704ce..4ece678fd 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -681,6 +681,14 @@ class FSDPConfig(BaseConfig):
 
     precision: FSDPPrecision = FSDPPrecision.pure
 
+    hybrid_sharding_num_model_replicas: Optional[int] = None
+    """
+    The number of model instances, when using a hybrid sharding strategy.
+    If not ``None``, this must divide the total number of nodes. If ``None``, the default,
+    a model instance is used per node (as determined by ``get_world_size() // get_local_world_size()``).
+    PyTorch's default HSDP behavior matches this default behavior.
+    """
+
 
 class CheckpointType(StrEnum):
     sharded = "sharded"


From aea725171a66271aeb840c68444e972b1b5a098b Mon Sep 17 00:00:00 2001
From: Shane A
Date: Tue, 9 Apr 2024 15:09:47 -0700
Subject: [PATCH 2/3] Use device mesh to hybrid shard across nodes

---
 scripts/train.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/scripts/train.py b/scripts/train.py
index f93734c0b..8bf2cd097 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -11,7 +11,9 @@
 import torch.multiprocessing as mp
 import wandb
 from packaging import version
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
 
 from olmo.config import CheckpointType, TrainConfig
 from olmo.data import build_train_dataloader
@@ -24,6 +26,7 @@
     get_default_device,
     get_global_rank,
     get_local_rank,
+    get_local_world_size,
     get_world_size,
     peak_gpu_memory,
     seed_all,
@@ -133,8 +136,32 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         param_init_fn = dummy_init_fn
     else:
         param_init_fn = None
+
+    # Set up device mesh for hybrid sharding in order to specify which nodes are associated with a given model replica
+    device_mesh: Optional[DeviceMesh] = None
+    if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
+        if version.parse(torch.__version__) < version.parse("2.2.0"):
+            # Device mesh was not added to PyTorch until v2.2.0
+            raise OLMoConfigurationError(
+                "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
+            )
+
+        num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
+            get_world_size() // get_local_world_size()
+        )
+
+        if num_model_replicas <= 0:
+            raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")
+
+        num_nodes = get_world_size() // get_local_world_size()
+        if num_nodes % num_model_replicas != 0:
+            raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")
+
+        device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
+
     fsdp_model = FSDP(
         olmo_model,
+        device_mesh=device_mesh,
         sharding_strategy=cfg.fsdp.sharding_strategy,
         mixed_precision=cfg.fsdp_precision,
         auto_wrap_policy=wrap_policy,
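For reviewers, here is a minimal standalone sketch (not part of the diff) of the mesh-shape arithmetic the new block performs before calling `init_device_mesh`. The helper name `hsdp_mesh_shape` and the world-size numbers are illustrative only:

```python
from typing import Optional, Tuple


def hsdp_mesh_shape(
    world_size: int, local_world_size: int, num_model_replicas: Optional[int] = None
) -> Tuple[int, int]:
    """Illustrative sketch of the replica/shard mesh arithmetic used in the patch."""
    num_nodes = world_size // local_world_size
    # Default: one model replica per node, matching PyTorch's default HSDP behavior.
    replicas = num_model_replicas or num_nodes
    if replicas <= 0:
        raise ValueError("number of model replicas must be a positive integer")
    if num_nodes % replicas != 0:
        raise ValueError("number of model replicas must divide the number of nodes")
    # First dim: replica groups (gradients are all-reduced across them);
    # second dim: ranks each replica is sharded over.
    return replicas, world_size // replicas


# 8 nodes x 8 GPUs with 2 replicas -> each replica sharded across 32 ranks.
assert hsdp_mesh_shape(64, 8, 2) == (2, 32)
# Default: one replica per node -> an (8, 8) mesh.
assert hsdp_mesh_shape(64, 8) == (8, 8)
```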
From 4f356323b67c70566b13794e10e615ba495e81af Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 15 Apr 2024 10:46:25 -0700
Subject: [PATCH 3/3] Update CHANGELOG for hybrid sharding model replicas

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8617f2f90..9f4bf369a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Makes it possible to read from http/https the same way we read from s3/r2.
 - Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
 - Tokenizer patch
+- Added option to specify number of model replicas when using hybrid sharding.
 
 ### Changed
 
@@ -35,7 +36,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed the size calculation for qk layer norm
 - Fixed pipeline test failure that occurs due to a bug in transformers version 4.39.1
 
-
 ## [v0.2.5](https://github.com/allenai/OLMo/releases/tag/v0.2.5) - 2024-03-06
 
 ### Fixed
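As a usage note, the sketch below shows how one might opt in to the new behavior. It assumes `FSDPConfig` can be constructed directly with keyword arguments; the `sharding_strategy` field is inferred from `cfg.fsdp.sharding_strategy` in `scripts/train.py`, and only `precision` and the new field appear in this patch:

```python
# Hypothetical usage sketch; assumes FSDPConfig accepts these keyword arguments.
from torch.distributed.fsdp import ShardingStrategy

from olmo.config import FSDPConfig, FSDPPrecision

fsdp_cfg = FSDPConfig(
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    precision=FSDPPrecision.pure,
    # With 8 nodes, 2 replicas means each replica is sharded across 4 nodes.
    hybrid_sharding_num_model_replicas=2,
)
```

Leaving `hybrid_sharding_num_model_replicas` unset keeps the default of one replica per node.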