
Improvements and fixes to gradient accumulation #993

Open · wants to merge 1 commit into main
Conversation

@apoorvtintin (Contributor) commented Feb 14, 2025

  • Fix the with_minibatch_steps decorator to generate correct primal output shapes.
  • Improve with_minibatch_steps to take a minibatch_partitioner that constrains the accumulation minibatches to the same PartitionSpec as the input_partitioner (a sketch of the idea follows after this list).

Misc:

  • Enable gradient accumulation for Fuji 3B on TRN2
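
In rough terms, the change aims to keep accumulation minibatches on the same layout the input_partitioner established for the full batch, instead of resharding them over every available axis inside the scan. Below is a minimal JAX sketch of that constraint; the helper name, the constant, and the rank-2 filter are illustrative assumptions, not the PR's actual code.

import jax
from jax.sharding import PartitionSpec

# Illustrative spec from this PR's discussion: shard the batch dim over
# ("data", "expert", "fsdp") and the sequence dim over "seq".
MINIBATCH_SPEC = PartitionSpec(("data", "expert", "fsdp"), "seq")


def constrain_minibatch(minibatch):
    """Re-applies the input sharding to every rank-2 leaf of a minibatch.

    Assumes it is called inside jit under the training mesh, so a bare
    PartitionSpec is a valid sharding constraint.
    """
    return jax.tree_util.tree_map(
        lambda x: jax.lax.with_sharding_constraint(x, MINIBATCH_SPEC)
        if getattr(x, "ndim", None) == 2
        else x,
        minibatch,
    )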

@apoorvtintin requested review from ruomingp, markblee and a team as code owners on February 14, 2025 01:10
# Note: the batch axes are different here than in
# `cfg.batch_axis_names`,
# as we partition sequence dim over `seq`.
(None, 1): PartitionSpec(("data", "expert", "fsdp")),
Contributor

I am wondering: if we have a default input partition with axis=0 on ("data", "expert", "fsdp") and axis=1 on "seq", do we still need this?

Contributor Author (@apoorvtintin, Feb 14, 2025)

Thanks for the quick review.
(None, 1) is for target_num_bytes and (None, 2) is for input_ids and target_labels, so we need both. Together they cover most cases, but for the outliers where a specific sharding is required, the ability to change the sharding of the minibatches is good to have; a sketch of the combined mapping is below.

Let me know if this answers your question.
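
For reference, the combined default mapping described in this reply looks roughly like the following. The (path, rank) keys and PartitionSpecs are taken from the diff; the field annotations are from this comment, assuming, as the discussion implies, that a None path matches any input field of the given rank.

from jax.sharding import PartitionSpec

# (path regex, rank) -> PartitionSpec; (None, rank) matches any path of that rank.
path_rank_to_partition = {
    # Rank-1 tensors such as target_num_bytes: shard the batch dim only.
    (None, 1): PartitionSpec(("data", "expert", "fsdp")),
    # Rank-2 tensors such as input_ids and target_labels: also shard the
    # sequence dim over "seq".
    (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
}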

),
input_partition_spec(),
Contributor

To me, this seems more like a hack than a proper solution; that is, whenever we want a different input_partition_spec() than the default one, we need this?

Contributor Author

Sorry, I missed the default case; I have added it.

I think the partition spec below is a good default, but the ability to change the PartitionSpec might still be good to have. What do you think?

(None, 1): PartitionSpec(("data", "expert", "fsdp")),
(None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"), 

@apoorvtintin force-pushed the mainline_grad_accum_fix branch 2 times, most recently from 9b0f9a3 to 32a78ea on February 14, 2025 23:15
- Fix the with_minibatch_steps decorator to generate correct primal output shapes.
- Improve with_minibatch_steps to take a minibatch_partitioner that constrains the input batch to the same PartitionSpec as the input_partitioner.
@@ -57,39 +59,38 @@ def _make_scan_minibatch_inputs(
    param_noise_key: Tensor,
    minibatch_size: int,
    minibatch_index: int,
    minibatch_partitioner: Optional[InputPartitionFn],
Contributor

Echoing Kelvin's comment, could you explain concretely why we need this functionality? If it's just something that might be useful, maybe we can wait until we are certain that we will need it?

Contributor Author (@apoorvtintin, Feb 19, 2025)

In the case where gradient accumulation is not enabled, the inputs to the graph are sharded per the policy in the input_partitioner. This ensures the batch dimension is sharded on the data, expert, and fsdp axes while the sequence dimension is replicated on the model axis.

Gradient accumulation wraps the train step in a scan loop. The input_partitioner shards the input batch correctly at first, but inside the gradient accumulation wrapper the input batches are resharded/overridden by _make_scan_minibatch_inputs and sharded along all available axes, which is probably unexpected and inefficient. Minibatches should follow the same PartitionSpec as the input batches.

Adding the minibatch_partitioner allows the minibatches to use the same sharding/PartitionSpec that the input_partitioner applies to the input batches when gradient accumulation is not used.
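
To make the flow above concrete, here is a rough sketch of the intended behavior, not the PR's actual _make_scan_minibatch_inputs: the scan body slices a minibatch from the full batch along axis 0 and then applies the caller-supplied minibatch_partitioner instead of resharding over all available axes. The helper name and the InputPartitionFn stand-in below are simplified assumptions.

from typing import Callable, Optional

import jax

# Simplified stand-in: an input partitioner takes a batch (a pytree of arrays)
# and returns it with sharding constraints applied.
InputPartitionFn = Callable[[dict], dict]


def make_minibatch(
    input_batch: dict,
    *,
    minibatch_size: int,
    minibatch_index: int,
    minibatch_partitioner: Optional[InputPartitionFn] = None,
) -> dict:
    """Slices minibatch `minibatch_index` along the batch dim, then reshards it."""
    minibatch = jax.tree_util.tree_map(
        lambda x: jax.lax.dynamic_slice_in_dim(
            x, minibatch_index * minibatch_size, minibatch_size, axis=0
        ),
        input_batch,
    )
    # Keep the sharding the input_partitioner established for the full batch
    # rather than constraining the minibatch along every available axis.
    if minibatch_partitioner is not None:
        minibatch = minibatch_partitioner(minibatch)
    return minibatch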

Contributor

If we just preserve the sharding the input already has, would that also address the concern about the input sharding being changed?

# Default partitioner for minibatches.
if not minibatch_partitioner:
    minibatch_partitioner = partition_by_path_rank(
        path_rank_to_partition={
Contributor

Can we default this to the same sharding the input is already using along all non-batch axes?

Contributor Author (@apoorvtintin, Feb 19, 2025)

Just confirming I read this correctly: do we want to default to input_partition_spec from utils.py as before, and not what the input_partitioner sets?

Or is the ask to use partition_by_path_rank to replicate what input_partition_spec was doing?

Contributor

Not exactly. I was envisioning that for all axes other than axis 0, we default to whatever sharding the input already has. For axis 0, ideally we would keep whatever sharding the input already has as well, although I'm not sure that would work with logical batching.
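
A rough sketch of that suggestion, assuming the inputs are concrete jax.Arrays with a NamedSharding (so the current PartitionSpec can be read from x.sharding.spec); the helper name is illustrative and not part of this PR.

import jax
from jax.sharding import NamedSharding, PartitionSpec


def nonbatch_preserving_constraint(x: jax.Array) -> jax.Array:
    """Keeps x's existing sharding on all axes except axis 0 (the batch dim)."""
    sharding = x.sharding
    if not isinstance(sharding, NamedSharding):
        return x  # No named spec to preserve; leave the array as-is.
    # Drop the constraint on axis 0 and keep the remaining entries verbatim.
    new_spec = PartitionSpec(None, *tuple(sharding.spec)[1:])
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(sharding.mesh, new_spec)
    )

Here the axis-0 spec is dropped; keeping it instead would simply mean not replacing the first entry, modulo the logical-batching caveat above.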
