[WIP] Used per-parameter FSDP #70

Closed
wants to merge 1 commit from the per_param branch

Conversation

awgu (Contributor) commented on Feb 23, 2024:

This PR shows how we would modify torchtrain to use per-parameter FSDP (sometimes called FSDP2 or ppFSDP). Detailed context can be found in pytorch/pytorch#114299.
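For orientation, here is a minimal sketch of what per-parameter FSDP looks like at the call site, assuming the fully_shard / MixedPrecisionPolicy API under torch.distributed._composable.fsdp as of early 2024 (dp_mesh, model, and the dtype choices are illustrative; this is not the exact PR diff):

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# Illustrative config; the PR builds an equivalent fsdp_config dict from the job config.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

for transformer_block in model.layers:
    # Each block becomes its own FSDP2 unit: its parameters are sharded
    # individually as DTensors and all-gathered around the block's forward/backward.
    fully_shard(transformer_block, **fsdp_config, reshard_after_forward=True)
# The root call picks up the remaining parameters (embeddings, final norm, output).
fully_shard(model, **fsdp_config)
```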

Progress Tracking
Gaps for 1D FSDP:

  • Support gradient norm clipping
    • The plan is to rely on DTensor's op dispatch to implement vector_norm and stack. This approach will allow us to use torch.nn.utils.clip_grad_norm_() out of the box as long as all DTensor parameters share the same sharding (a minimal sketch follows this list).
    • Gradient norm clipping w/o foreach=True: [DTensor] Supported foreach=False for clip_grad_norm_ pytorch#120238
    • Gradient norm clipping w/ foreach=True
  • Integrate with distributed checkpointing
  • Meta-device initialization
    • The plan is to allow FSDP to shard on meta device without materializing parameters. After applying FSDP, the user can call model.to_empty(device="cuda") and initialize parameters/buffers as needed (e.g., via reset_parameters() on each module). The initialization will use DTensor's randomness, which ensures correct randomness with respect to the global shape (a rough sketch of this flow appears after the milestone note below).
    • There are some minor gaps in core that will be fixed soon.
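As a hedged illustration of the clipping plan above (a sketch under the assumption that DTensor gains vector_norm/stack dispatch and that all DTensor parameters share the same sharding; not code from this PR):

```python
import torch

def clip_gradients(model: torch.nn.Module, max_norm: float = 1.0) -> torch.Tensor:
    # Once DTensor dispatches vector_norm and stack, the stock utility is expected
    # to work unchanged on FSDP2's DTensor gradients; the returned total norm then
    # reflects the global (unsharded) gradient norm.
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
```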

The goal is to land PRs addressing these gaps by 3/8. The overall execution tracker can be found in pytorch/pytorch#120003.
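A rough sketch of the planned meta-device initialization flow described above (an assumption-level outline, not the final torchtrain code; model_cls, dp_mesh, and the layers attribute are illustrative):

```python
import torch
from torch.distributed._composable.fsdp import fully_shard

def meta_init_and_shard(model_cls, model_args, dp_mesh):
    # 1. Construct on the meta device: no parameter memory is allocated yet.
    with torch.device("meta"):
        model = model_cls(model_args)

    # 2. Shard while still on meta device; parameters become (unmaterialized) DTensors.
    for block in model.layers:  # assumes a Llama-style module list of transformer blocks
        fully_shard(block, mesh=dp_mesh)
    fully_shard(model, mesh=dp_mesh)

    # 3. Materialize empty local shards on GPU, then run each module's own init.
    #    Per the plan above, initialization goes through DTensor's randomness so the
    #    result is correct with respect to the global (unsharded) shape.
    model.to_empty(device="cuda")
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
    return model
```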

Gaps for 2D FSDP + TP/SP:

  • Support torch.compile
    • Current error: P1192664996
    • We may need to add FSDP2 to some skipfiles, similar to existing FSDP, so that we graph-break cleanly on the pre-forward hook.

Experiments
Llama2-7B
8 H100 GPUs, local batch size 8, sequence length 2048, with activation checkpointing and torch.compile applied to each transformer block

  • Baseline: existing flat-parameter FSDP (FULL_SHARD)
    • 2124 ms per iteration
    • 23.24 GB peak active memory
  • Per-parameter FSDP (reshard_after_forward=True)
    • 2143 ms per iteration
    • 22.11 GB peak active memory
  • Per-parameter FSDP (reshard_after_forward=False for the last transformer block)
    • 2127 ms per iteration
    • 22.49 GB peak active memory (if the peak is at the beginning of backward, this increase makes sense)

(Ideally, we would report per-GPU MFU or tokens per second rather than time per iteration, which is less interpretable.)
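For reference, a back-of-the-envelope tokens-per-second number can be derived from the figures above (a rough estimate only; the proper metric would come from the training loop itself):

```python
# Rough throughput estimate for the baseline run: 8 GPUs, local batch size 8,
# sequence length 2048, 2124 ms per iteration.
gpus, local_batch, seq_len = 8, 8, 2048
tokens_per_iter = gpus * local_batch * seq_len  # 131,072 tokens per iteration
iter_time_s = 2.124
tokens_per_sec_per_gpu = tokens_per_iter / iter_time_s / gpus
print(f"{tokens_per_sec_per_gpu:,.0f} tokens/s per GPU")  # roughly 7,700
```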

)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
    transformer_block = checkpoint_wrapper(transformer_block, args)
awgu (Contributor Author) commented on this diff:

If we could get _composable.checkpoint to work with compile / selective AC, then we could replace with

checkpoint(transformer_block)

and not require the assignment model.layers[layer_id] = transformer_block.
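To contrast the two styles side by side, here is a hedged sketch using the underlying PyTorch APIs directly (the PR itself goes through torchtrain's own checkpoint_wrapper helper, which takes config args; the composable checkpoint call is the one the comment proposes once it works with compile / selective AC):

```python
from torch.distributed._composable import checkpoint
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

def apply_ac_wrapper_style(model):
    # Wrapper style (what the diff above does): the wrapper is a new module,
    # so each entry in model.layers must be reassigned to the wrapped block.
    for layer_id, transformer_block in enumerate(model.layers):
        model.layers[layer_id] = checkpoint_wrapper(transformer_block)

def apply_ac_composable_style(model):
    # Composable style: checkpointing is installed in place via hooks,
    # so no reassignment of model.layers entries is needed.
    for transformer_block in model.layers:
        checkpoint(transformer_block)
```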

lessw2020 (Contributor) left a comment:

LGTM, exciting to see this being rolled out.
Left two minor nits (not sure if 7B as the default is intentional; maybe at least add a TODO noting that grad clipping is disabled and why, instead of simply commenting it out).

awgu force-pushed the per_param branch 6 times, most recently from 56f1ae0 to cb0f563 on February 29, 2024 at 15:58
awgu force-pushed the per_param branch 4 times, most recently from 27f3286 to f33edba on March 5, 2024 at 17:01
@@ -152,6 +147,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            ),
        },
    )
    distribute_rmsnorm(model.norm, tp_mesh)
awgu (Contributor Author) commented on this diff:

cc @wanchaol @tianyu-l:
FSDP2 currently checks that, if it uses a mesh that has a parent mesh, then the parameters it shards are already DTensors. However, without this distribute_rmsnorm(model.norm, tp_mesh) call, Transformer.norm.weight is a plain torch.Tensor.

Would the sharded state dict for existing FSDP be incorrect without making it a DTensor? Should we assert that all parameters in the model are DTensors after parallelizing with TP/SP, or is that not a requirement? (If it is not a requirement, then we may need to relax the check on the FSDP2 side.)
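As a rough illustration of the assertion being floated here (an assumption-level sketch, not something this PR adds), the invariant could be checked right after TP/SP parallelization and before applying FSDP2:

```python
from torch.distributed._tensor import DTensor

def assert_all_params_are_dtensor(model):
    # FSDP2 currently expects parameters under a mesh with a parent mesh to already
    # be DTensors; any plain torch.Tensor left over (e.g. the final RMSNorm weight
    # without distribute_rmsnorm) would trip that check.
    plain = [name for name, p in model.named_parameters() if not isinstance(p, DTensor)]
    assert not plain, f"non-DTensor parameters after TP/SP: {plain}"
```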
