[RFC] Unify activation checkpointing APIs #2114

Open
ebsmothers opened this issue Dec 5, 2024 · 3 comments
@ebsmothers (Contributor) commented Dec 5, 2024

Where are we today?

We currently provide two different APIs for activation checkpointing.

  1. set_activation_checkpointing is the default for most of our recipes and its contract is similar to that of FSDP wrapping: the user either provides a set of nn.Module types to wrap or a Callable[[nn.Module, bool, int], bool] (see a more detailed description here). In practice we only use the first approach in our recipes, see here.

  2. apply_selective_activation_checkpointing was added as a prototype feature. It has only been integrated into two of our recipes (distributed full finetune and distributed QAT) and is only exposed in a single dev config. It is not currently tested.

Currently, neither of these APIs provides a superset of the other's functionality.
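
For reference, a rough sketch of how the two are invoked today (import paths are per torchtune; the apply_selective_activation_checkpointing signature shown is approximate):

import torch.nn as nn
from torchtune import modules, training

model: nn.Module = ...  # built by the recipe's model setup

# (1) wrap every module of a given type, FSDP-wrapping-style
training.set_activation_checkpointing(
    model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# (2) prototype SAC utility (distributed full finetune / QAT only);
# the exact arguments here are approximate
training.apply_selective_activation_checkpointing(model, ac_mode="selective", ac_option=2)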

What needs to change?

Having both of these APIs is redundant and potentially confusing (see e.g. here). We should consolidate behind a single, clean, well-tested API.

What are the requirements?

Imo our AC API should definitely support:

  • Wrapping a set of nn.Module types (i.e. the first case from (1))
  • Selective activation checkpointing (SAC) of every $k^{th}$ layer

However, I claim we do not need to support the second case from the set_activation_checkpointing API (I think the Callable contract is a bit confusing and it doesn't actually cover the most common SAC case of checkpointing every $k^{th}$ layer). Separately, there is op-level SAC as demonstrated in torchtitan here. Imo this is nice-to-have but not must-have as it requires some custom handling.

Proposal

Assuming this, I propose we take a similar approach to our current shard_model utility.

from typing import Callable, List

import torch.nn as nn


def apply_activation_checkpointing(
    model: nn.Module,
    ac_conditions: List[Callable[[str, nn.Module], bool]],
) -> None:
    # Walk modules in reverse so children are wrapped before their parents
    for n, m in reversed(list(model.named_modules())):
        if any(ac_condition(n, m) for ac_condition in ac_conditions):
            ...  # apply AC wrapping to this module

Then we can address the first case with e.g. ac_condition = lambda n, m: isinstance(m, TransformerSelfAttentionLayer) and the second with ac_condition = lambda n, m: get_layer_num(n) % k == 0, where get_layer_num is a utility to infer the layer number from the fully qualified module name.
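
A rough sketch of what these conditions could look like in practice (get_layer_num is a hypothetical helper, not an existing torchtune utility, and the regex assumes the layers live under a layers attribute somewhere in the module tree):

import re
from typing import Optional

from torchtune.modules import TransformerSelfAttentionLayer

# Hypothetical helper: pull the layer index out of a fully qualified module name,
# e.g. "layers.3" or "decoder.layers.3" -> 3; returns None for non-layer modules.
def get_layer_num(name: str) -> Optional[int]:
    match = re.search(r"(?:^|\.)layers\.(\d+)$", name)
    return int(match.group(1)) if match else None

k = 2
ac_conditions = [
    # case (1): wrap every module of a given type
    lambda n, m: isinstance(m, TransformerSelfAttentionLayer),
    # case (2): wrap every k-th layer
    lambda n, m: get_layer_num(n) is not None and get_layer_num(n) % k == 0,
]

These would then be passed straight to the apply_activation_checkpointing sketch above.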

Potential drawbacks of this approach: (1) we maybe (?) need to do some setattr magic to handle e.g. this, and (2) parsing layer numbers out of module names may feel a bit hacky compared to what we currently do in apply_selective_activation_checkpointing. But imo this is worth it for the increased generality (that utility assumes we are applying it to a model with a list of layers as a top-level attribute).
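
On drawback (1): assuming the setattr magic refers to swapping a checkpoint-wrapped module back into its parent, the wrapping step might look roughly like this (a sketch only; checkpoint_wrapper is the existing PyTorch wrapper, and the name-splitting logic is illustrative):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

def _swap_in_wrapped_module(model: nn.Module, name: str, module: nn.Module) -> None:
    # Split e.g. "layers.3" into parent ("layers") and attribute ("3"),
    # then replace the child with a checkpoint-wrapped version in place.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, checkpoint_wrapper(module))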

ebsmothers added the discussion label Dec 5, 2024
ebsmothers mentioned this issue Dec 5, 2024
ebsmothers changed the title from "Unify activation checkpointing APIs" to "[RFC] Unify activation checkpointing APIs" Dec 5, 2024
@joecummings (Contributor) commented:

Overall - good. Support. Nice.

Some comments:

  1. ac_conditions is maybe too loose of a contract to allow into this public API. I might prefer something like a CheckpointingPolicy class, which is more in line with how FSDP2 kinda works, and torchtitan uses it as well.
  2. I think you will need something like this

@SalmanMohammadi (Collaborator) commented:

Then we can address the first case with e.g. ac_condition = lambda n, m: isinstance(m, TransformerSelfAttentionLayer) and the second with ac_condition = lambda n, m: get_layer_num(n) % k == 0, where get_layer_num is a utility to infer the layer number from the fully qualified module name.

This is very nice.

Since this RFC relates to our APIs around activation checkpointing, I wonder what the UX would be like for using and customizing this component in recipes? A couple thoughts I had:

  1. This is a relatively niche feature (is it?), so maybe we don't need a way to configure it from the config. If a user wants to make a change here, they're forking/copying the recipe file.

or

  1. We want people to be able to configure this easily, with something like:
ac_conditions:
  - torchtune.training.activations.selective_active_checkpointing_condition:
      k: 4
  - my_repo.only_checkpointing_attention_projections

This would involve offering some sensible conditions from torchtune, and making it easy to swap them in and out.

@ebsmothers (Contributor, Author) commented:

@joecummings and @SalmanMohammadi thanks for the comments.

ac_conditions is maybe too loose of a contract to allow into this public API. I might prefer something like a CheckpointingPolicy class, which is more in line with how FSDP2 kinda works, and torchtitan uses it as well.

Can you clarify what part of CheckpointPolicy you'd like us to adopt? There is the must-save vs. prefer-recompute piece, which seems more granular than what we need. At least the way torchtitan uses it is for SAC on individual ops. I am open to having us support that but considered it more of a nice-to-have (cause yeah, I agree my proposal will not really work for that case). Actually I think this is the most important detail before moving further on design, so question for the room... do we want to support op-level SAC?
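
(For context, the must-vs-prefer save/recompute piece refers to op-level SAC in torch.utils.checkpoint. A minimal sketch of what supporting it could look like, assuming the CheckpointPolicy / create_selective_checkpoint_contexts API available in recent PyTorch; the save list below is just an example, not torchtitan's actual policy:)

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Example save list: always keep matmul outputs, recompute everything else in backward
_SAVE_LIST = {torch.ops.aten.mm.default}

def _policy_fn(ctx, op, *args, **kwargs):
    if op in _SAVE_LIST:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpointed_forward(layer, x):
    context_fn = partial(create_selective_checkpoint_contexts, _policy_fn)
    return checkpoint(layer, x, use_reentrant=False, context_fn=context_fn)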

I think you will need something like this

Yeah this should be the same as this code pointer, right? Although I guess this way some hooks fire as well.

Since this RFC relates to our APIs around activation checkpointing, I wonder what the UX would be like for using and customizing this component in recipes? A couple thoughts I had:

  1. This is a relatively niche feature (is it?), so maybe we don't need a way to configure it from the config. If a user wants to make a change here, they're forking/copying the recipe file.

or

  1. We want people to be able to configure this easily, with something like:
ac_conditions:
  - torchtune.training.activations.selective_active_checkpointing_condition:
      k: 4
  - my_repo.only_checkpointing_attention_projections

This would involve offering some sensible conditions from torchtune, and making it easy to swap them in and out.

Imo if we are gonna properly support SAC it should be doable from the config. I like your example in (2), though I think we may need to support partial instantiation to get the every-k-layers example to work cleanly. And separately I wanna figure out how we can keep the bool flag for vanilla checkpoint-every-layer as our first-class citizen.
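
(One way the partial-instantiation piece could work, as a sketch; selective_ac_condition is a hypothetical factory name, not an existing torchtune utility:)

import re
from typing import Callable

import torch.nn as nn

# Hypothetical factory: the config instantiates this with k, yielding a (name, module) -> bool condition
def selective_ac_condition(k: int) -> Callable[[str, nn.Module], bool]:
    def _condition(name: str, module: nn.Module) -> bool:
        match = re.search(r"(?:^|\.)layers\.(\d+)$", name)
        return match is not None and int(match.group(1)) % k == 0
    return _condition

# roughly what the ac_conditions config entry above could resolve to
ac_conditions = [selective_ac_condition(k=4)]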

joecummings added the rfc label and removed the discussion label Dec 10, 2024