Hybrid sharding with specific number of replicas #540

Merged · 5 commits · Apr 17, 2024
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Makes it possible to read from http/https the same way we read from s3/r2.
- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
- Tokenizer patch
- Added option to specify number of model replicas when using hybrid sharding.

### Changed

@@ -35,7 +36,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the size calculation for qk layer norm
- Fixed pipeline test failure that occurs due to a bug in transformers version 4.39.1


## [v0.2.5](https://github.com/allenai/OLMo/releases/tag/v0.2.5) - 2024-03-06

### Fixed
olmo/config.py (8 changes: 8 additions & 0 deletions)
@@ -687,6 +687,14 @@ class FSDPConfig(BaseConfig):

precision: FSDPPrecision = FSDPPrecision.pure

hybrid_sharding_num_model_replicas: Optional[int] = None
"""
The number of model replicas to use when a hybrid sharding strategy is selected.
If not ``None``, this must evenly divide the total number of nodes. If ``None`` (the default),
one model replica is used per node (i.e. ``get_world_size() // get_local_world_size()`` replicas),
which matches PyTorch's default HSDP behavior.
"""


class CheckpointType(StrEnum):
sharded = "sharded"
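A minimal sketch (not part of this diff) of how the new option could be set from Python. It assumes ``FSDPConfig`` also exposes a ``sharding_strategy`` field, as suggested by the ``cfg.fsdp.sharding_strategy`` reference in ``scripts/train.py``; only ``precision`` and ``hybrid_sharding_num_model_replicas`` actually appear in this diff.

```python
# Hedged sketch: build an FSDPConfig requesting 4 model replicas under hybrid sharding.
# `sharding_strategy` as a field of FSDPConfig is an assumption, not shown in this diff.
from torch.distributed.fsdp import ShardingStrategy

from olmo.config import FSDPConfig, FSDPPrecision

fsdp_cfg = FSDPConfig(
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # assumed field
    precision=FSDPPrecision.pure,
    hybrid_sharding_num_model_replicas=4,  # e.g. 4 replicas across 8 nodes
)
```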
scripts/train.py (27 changes: 27 additions & 0 deletions)
@@ -11,7 +11,9 @@
import torch.multiprocessing as mp
import wandb
from packaging import version
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

from olmo.config import CheckpointType, TrainConfig
from olmo.data import build_train_dataloader
@@ -24,6 +26,7 @@
get_default_device,
get_global_rank,
get_local_rank,
get_local_world_size,
get_world_size,
peak_gpu_memory,
seed_all,
@@ -133,8 +136,32 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
param_init_fn = dummy_init_fn
else:
param_init_fn = None

    # Set up a device mesh for hybrid sharding in order to specify which nodes are associated with a given model replica
device_mesh: Optional[DeviceMesh] = None
if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
if version.parse(torch.__version__) < version.parse("2.2.0"):
# Device mesh was not added to PyTorch until v2.2.0
raise OLMoConfigurationError(
"OLMo training does not correctly support hybrid sharding before torch 2.2.0"
)

num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
get_world_size() // get_local_world_size()
)

if num_model_replicas <= 0:
raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")

num_nodes = get_world_size() // get_local_world_size()
if num_nodes % num_model_replicas != 0:
raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")

device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))

fsdp_model = FSDP(
olmo_model,
device_mesh=device_mesh,
sharding_strategy=cfg.fsdp.sharding_strategy,
mixed_precision=cfg.fsdp_precision,
auto_wrap_policy=wrap_policy,
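For reference, a standalone sketch (with assumed world and node sizes, not taken from the PR) of how the replica count maps onto the two-dimensional mesh shape passed to ``init_device_mesh`` above:

```python
# Assumed example values: 8 nodes with 8 GPUs each.
world_size = 64
local_world_size = 8
num_nodes = world_size // local_world_size  # 8 nodes

# Requested via fsdp.hybrid_sharding_num_model_replicas;
# must be positive and must divide the number of nodes.
num_model_replicas = 4
assert num_model_replicas > 0
assert num_nodes % num_model_replicas == 0

# Mesh shape used above: (replica dimension, shard dimension).
mesh_shape = (num_model_replicas, world_size // num_model_replicas)
print(mesh_shape)  # (4, 16): 4 replicas, each sharded across 16 GPUs (2 nodes per replica)
```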