[WIP] Used per-parameter FSDP #70
Conversation
```python
)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
    transformer_block = checkpoint_wrapper(transformer_block, args)
```
If we could get `_composable.checkpoint` to work with compile / selective AC, then we could replace this with `checkpoint(transformer_block)` and not require the assignment `model.layers[layer_id] = transformer_block`.
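For concreteness, here is a minimal sketch of the two styles being contrasted. It assumes PyTorch's private `torch.distributed._composable.checkpoint` API and the stock `checkpoint_wrapper` (not torchtrain's `args`-taking helper from the diff above); `model` is assumed to be a module whose `layers` attribute holds the transformer blocks.

```python
from torch.distributed._composable import checkpoint
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# Wrapper style: checkpoint_wrapper returns a new module, so the parent
# container must be updated with the wrapped block.
for layer_id, transformer_block in enumerate(model.layers):
    wrapped_block = checkpoint_wrapper(transformer_block)
    model.layers[layer_id] = wrapped_block

# Composable style: checkpoint() registers activation checkpointing on the
# module in place, so no reassignment into model.layers is needed.
for transformer_block in model.layers:
    checkpoint(transformer_block)
```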
lgtm - exciting to see this being rolled out.
Left two minor nits. (Not sure if 7B as the default is intentional; maybe at least add a TODO noting that clip grad is disabled and why, instead of simply commenting it out.)
```diff
@@ -152,6 +147,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             ),
         },
     )
+    distribute_rmsnorm(model.norm, tp_mesh)
```
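As a hypothetical sketch only (an assumption about what a helper like `distribute_rmsnorm` could do, not its actual source), the idea is to replace the norm's weight with a replicated `DTensor` on the TP mesh so that FSDP2's parent-mesh check passes:

```python
import torch.nn as nn
from torch.distributed._tensor import Replicate, distribute_tensor

def distribute_rmsnorm(norm_module: nn.Module, tp_mesh) -> None:
    # Convert the plain torch.Tensor weight into a DTensor replicated
    # across the tensor-parallel mesh.
    weight = norm_module.weight
    dtensor_weight = nn.Parameter(distribute_tensor(weight, tp_mesh, [Replicate()]))
    norm_module.register_parameter("weight", dtensor_weight)
```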
cc @wanchaol @tianyu-l: FSDP2 currently checks that if it uses a mesh that has a parent mesh, then the parameter to shard should be a `DTensor`. However, without this `distribute_rmsnorm(model.norm, tp_mesh)`, the `Transformer.norm.weight` is a plain `torch.Tensor`.

Would the sharded state dict for existing FSDP be incorrect without making it a `DTensor`? Should we assert that all parameters in the model are `DTensor` after parallelizing with TP/SP, or is that not a requirement? (If not a requirement, then we may need to relax the check on the FSDP2 side.)
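A minimal sketch of the kind of assertion being proposed (the helper name `assert_all_params_are_dtensor` is made up for illustration), run after the model has been parallelized with TP/SP:

```python
import torch.nn as nn
from torch.distributed._tensor import DTensor

def assert_all_params_are_dtensor(model: nn.Module) -> None:
    """Flag any parameter left as a plain torch.Tensor after TP/SP."""
    plain = [
        name
        for name, param in model.named_parameters()
        if not isinstance(param, DTensor)
    ]
    if plain:
        # e.g. Transformer.norm.weight stays a plain torch.Tensor unless
        # distribute_rmsnorm(model.norm, tp_mesh) converts it.
        raise AssertionError(f"Non-DTensor parameters after TP/SP: {plain}")
```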
This PR shows how we would modify torchtrain to use per-parameter FSDP (sometimes called FSDP2 or ppFSDP). Detailed context can be found in pytorch/pytorch#114299.
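For orientation, a rough sketch (not the PR's actual diff) of what applying per-parameter FSDP per transformer block looks like, assuming the `torch.distributed._composable.fsdp` module from PyTorch nightlies of that period; `model` and `dp_mesh` are placeholders:

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

for layer_id, transformer_block in enumerate(model.layers):
    # The last block's parameters are re-gathered immediately at the start of
    # backward, so skipping its reshard-after-forward can save a collective.
    reshard = layer_id < len(model.layers) - 1
    fully_shard(transformer_block, **fsdp_config, reshard_after_forward=reshard)
fully_shard(model, **fsdp_config)
```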
Progress Tracking

Gaps for 1D FSDP:
- Gradient norm clipping: use `DTensor`'s op dispatch to implement `vector_norm` and `stack`. This approach will allow us to use `torch.nn.utils.clip_grad_norm_()` out of the box as long as all `DTensor` parameters share the same sharding.
  - `foreach=False`: [DTensor] Supported `foreach=False` for `clip_grad_norm_` pytorch#120238
  - `foreach=True`
- Meta-device initialization: call `model.to_empty(device="cuda")` and initialize parameters/buffers as needed (e.g. via `reset_parameters()` on each module); see the sketch after this list. The initialization will use `DTensor`'s randomness, which ensures correct randomness with respect to the global shape.

Goal is to land PRs to address these gaps by 3/8. The overall execution tracker can be found in pytorch/pytorch#120003.
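A hedged sketch of the meta-device initialization flow from the list above (the `Transformer(model_args)` constructor is a placeholder, and modules are assumed to define `reset_parameters()`):

```python
import torch

# Construct on the meta device so no real memory is allocated yet.
with torch.device("meta"):
    model = Transformer(model_args)  # placeholder constructor

# ... apply TP/SP and per-parameter FSDP here so parameters become DTensors ...

# Materialize uninitialized storage on GPU, then initialize in place.
model.to_empty(device="cuda")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()  # DTensor-aware RNG keeps init consistent
                                   # with the global (unsharded) shape

# Once vector_norm/stack are supported in DTensor's op dispatch, gradient
# clipping should work out of the box:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```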
Gaps for 2D FSDP + TP/SP:
- `torch.compile`
Experiments

Llama2-7B: 8 H100 GPUs, local batch size 8, sequence length 2048, activation checkpointing, and `torch.compile` applied to each transformer block. Configurations compared:
- existing FSDP (`FULL_SHARD`)
- per-parameter FSDP (`reshard_after_forward=True`)
- per-parameter FSDP (`reshard_after_forward=False` for the last transformer block)

(Ideally, we can have some per-GPU MFU or tokens-per-second metric rather than time per iteration, which is less interpretable.)
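For reference, a small sketch of how such metrics could be derived from time per iteration; the numbers plugged in below are placeholders, not measurements from this PR.

```python
def tokens_per_sec_per_gpu(local_batch_size: int, seq_len: int, iter_time_s: float) -> float:
    # Each GPU processes local_batch_size * seq_len tokens per iteration.
    return local_batch_size * seq_len / iter_time_s

def mfu(num_params: float, tok_per_sec_per_gpu: float, peak_flops_per_gpu: float) -> float:
    # Rough estimate: ~6 FLOPs per parameter per token for forward + backward,
    # ignoring attention FLOPs.
    return 6 * num_params * tok_per_sec_per_gpu / peak_flops_per_gpu

# Placeholder usage: local batch 8, seq len 2048, hypothetical 1.0 s/iteration,
# 7e9 parameters, H100 BF16 dense peak ~989 TFLOP/s.
tps = tokens_per_sec_per_gpu(8, 2048, 1.0)
print(f"tokens/s/GPU: {tps:.0f}, MFU: {mfu(7e9, tps, 989e12):.1%}")
```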