[RFC] Unify activation checkpointing APIs #2114
Comments
Overall: looks good, I support this. Some comments:
This is very nice. Since this RFC relates to our APIs around activation checkpointing, I wonder what the UX would be like for using and customizing this component in recipes? A couple of thoughts I had:

or

ac_conditions:
  - torchtune.training.activations.selective_active_checkpointing_condition
    k: 4
  - my_repo.only_checkpointing_attention_projections

This would involve offering some sensible conditions from torchtune, and making it easy to swap them in and out.
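For illustration, a condition like the first entry could be a small factory that, once (partially) instantiated with `k` from the config, returns a `Callable[[str, nn.Module], bool]`. This is only a sketch: the function name mirrors the hypothetical config entry above and is not an existing torchtune API, and the `"layers.<idx>"` naming is an assumption about how decoder layers are registered.

```python
import re
from typing import Callable

import torch.nn as nn


def selective_active_checkpointing_condition(
    k: int,
) -> Callable[[str, nn.Module], bool]:
    """Return a condition that matches every k-th decoder layer."""

    def _condition(name: str, module: nn.Module) -> bool:
        # assumes layers are registered under a top-level "layers" ModuleList,
        # so fully qualified module names look like "layers.0", "layers.1", ...
        match = re.fullmatch(r"layers\.(\d+)", name)
        return match is not None and int(match.group(1)) % k == 0

    return _condition
```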
@joecummings and @SalmanMohammadi thanks for the comments.
Can you clarify what part of …?
Yeah this should be the same as this code pointer, right? Although I guess this way some hooks fire as well.
Imo if we are gonna properly support SAC, it should be doable from the config. I like your example in (2), though I think we may need to support partial instantiate to get the every-k-layers example to work cleanly. And separately, I want to figure out how we can keep the bool flag for vanilla checkpoint-every-layer as our first-class citizen.
Where are we today?
We currently provide two different APIs for activation checkpointing.
- set_activation_checkpointing is the default for most of our recipes, and its contract is similar to that of FSDP wrapping: the user provides either a set of nn.Module types to wrap or a Callable[[nn.Module, bool, int], bool] (see a more detailed description here). In practice we only use the first approach in our recipes, see here.
- apply_selective_activation_checkpointing was added as a prototype feature. It has only been integrated into two of our recipes (distributed full finetune and distributed QAT) and is only exposed in a single dev config. It is not currently tested.
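For reference, here is a minimal sketch of how the two contracts of set_activation_checkpointing are exercised. It assumes the current `torchtune.training` import path and the `auto_wrap_policy` keyword; `model` is any torchtune model built elsewhere.

```python
import torch.nn as nn
from torchtune.modules import TransformerSelfAttentionLayer
from torchtune.training import set_activation_checkpointing


def enable_ac_by_type(model: nn.Module) -> None:
    # (a) the set-of-module-types contract, which is what our recipes use today
    set_activation_checkpointing(
        model, auto_wrap_policy={TransformerSelfAttentionLayer}
    )


def enable_ac_with_callable(model: nn.Module) -> None:
    # (b) the FSDP-style Callable[[nn.Module, bool, int], bool] contract
    def policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
        # keep recursing into children; wrap at transformer layer boundaries
        return recurse or isinstance(module, TransformerSelfAttentionLayer)

    set_activation_checkpointing(model, auto_wrap_policy=policy)
```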
Currently neither of these APIs provides a superset of functionality of the other.
What needs to change?
Having both these APIs is redundant and potentially confusing. See e.g. here. We should consolidate behind a single, clean, well-tested API.
What are the requirements?
Imo our AC API should definitely support:

- full activation checkpointing applied to all modules of a given type (e.g. every TransformerSelfAttentionLayer), which is what our recipes do today, and
- selective activation checkpointing of every $k^{th}$ layer.

However, I claim we do not need to support the second case from the set_activation_checkpointing API (the Callable contract is a bit confusing, and it doesn't actually cover the most common SAC case of checkpointing every $k^{th}$ layer). Separately, there is op-level SAC as demonstrated in torchtitan here. Imo this is nice-to-have but not must-have, as it requires some custom handling; a rough sketch of that handling is given below.
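For context, op-level SAC roughly amounts to supplying a per-op policy to torch.utils.checkpoint. The sketch below shows the general PyTorch mechanism that torchtitan builds on (not torchtitan's exact code); it assumes the `create_selective_checkpoint_contexts` / `CheckpointPolicy` APIs available in recent PyTorch, and the op list is purely illustrative.

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# illustrative policy: always save matmul outputs, recompute everything else
_SAVE_OPS = {torch.ops.aten.mm.default}


def _policy(ctx, op, *args, **kwargs):
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _SAVE_OPS
        else CheckpointPolicy.PREFER_RECOMPUTE
    )


def checkpointed_layer_forward(layer, x):
    # checkpoint the layer, but let the per-op policy decide what to save
    return checkpoint(
        layer,
        x,
        use_reentrant=False,
        context_fn=partial(create_selective_checkpoint_contexts, _policy),
    )
```

This is the "custom handling" referred to above: it changes how the checkpointed region is called rather than just which modules get wrapped.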
Proposal
Assuming this, I propose we take a similar approach to our current shard_model utility.
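To make the shape of this concrete, here is a minimal sketch of what such a shard_model-style utility could look like. The function name is hypothetical (not an existing torchtune API), and reusing checkpoint_wrapper from torch.distributed is my assumption; the conditions follow the Callable[[str, nn.Module], bool] contract used in the examples that follow.

```python
from typing import Callable, List

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def apply_activation_checkpointing_v2(
    model: nn.Module,
    ac_conditions: List[Callable[[str, nn.Module], bool]],
) -> None:
    """Wrap (in place) every submodule matching any condition in an AC wrapper."""
    # snapshot first so we don't mutate the module tree while iterating over it
    named_modules = [(n, m) for n, m in model.named_modules() if n]
    for name, module in named_modules:
        if any(cond(name, module) for cond in ac_conditions):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            # the setattr "magic" referred to under the drawbacks below
            setattr(parent, child_name, checkpoint_wrapper(module))
```

A recipe would then build the list of conditions (from config or code) and call this once after model init.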
Then we can address the first case with e.g.
ac_condition = lambda n, m: isinstance(m, TransformerSelfAttentionLayer)
and the second with
ac_condition = lambda n, m: get_layer_num(n) % k == 0
where get_layer_num is a utility to infer the layer number from the full parameter name; a minimal sketch of such a helper is given below.
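A minimal sketch of such a helper, assuming decoder layers are registered under a top-level "layers" attribute so fully qualified names look like "layers.3.attn":

```python
import re
from typing import Optional


def get_layer_num(module_name: str) -> Optional[int]:
    """Infer the layer index from a fully qualified module name, e.g. "layers.3.attn" -> 3."""
    match = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", module_name)
    return int(match.group(1)) if match else None
```

In the condition above, names without a recoverable layer number would need to be skipped (e.g. by returning False whenever this helper returns None).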
Potential drawbacks of this approach: (1) we may (?) need to do some setattr magic to handle e.g. this. And (2) parsing strings to infer layer numbers may feel a bit hacky compared to what we currently do in apply_selective_activation_checkpointing. But imo this is worth it for the increased generality (e.g. that utility assumes we are applying it to a model with a list of layers as a top-level attribute).