Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT REVIEW] --experimental.fsdp_sharding_on_largest_dim #607

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.fsdp_sharding_on_largest_dim",
action="store_true",
help="""
sharding on largest dim to reduce padding
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
Expand Down
20 changes: 19 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def parallelize_llama(
],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
shard_on_largest_dim=job_config.experimental.fsdp_sharding_on_largest_dim,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
Expand Down Expand Up @@ -299,12 +300,29 @@ def apply_fsdp(
reduce_dtype: torch.dtype,
tp_enabled: bool,
pp_enabled: bool,
shard_on_largest_dim: bool,
):
"""
Apply data parallelism to the model. FSDP2 is used here.
"""

def shard_placement_fn(param: nn.Parameter):
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
assert largest_dim < param.ndim, f"{largest_dim=} {param.shape}"
return Shard(largest_dim)

mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
fsdp_config = {
"mesh": dp_mesh,
"mp_policy": mp_policy,
"shard_placement_fn": shard_placement_fn if shard_on_largest_dim else None,
}

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
Expand Down
Loading