Selective Activation Checkpointing with LayerNormMLP #623

@denizokt

Description

Hi all,

I was wondering whether it is possible to do selective activation checkpointing with LayerNormMLP, where only FFN1 is recomputed and not FFN2, so that the ffn1_out and gelu_out activations (the largest activations by memory) do not need to be saved.

This has been done in OPT (see https://github.com/facebookresearch/metaseq/blob/f7ffa5fd61cf90f498a36d365c13dd7f1a912ff7/metaseq/modules/sequence_parallel_transformer_layer.py#L250C20-L250C33), so I wonder whether the same is possible in TransformerEngine, since it would be awesome to use it with FP8!
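
To illustrate what I mean, here is a minimal plain-PyTorch sketch (not TransformerEngine's API; the class and variable names such as SelectiveRecomputeMLP, ln_out, and w1/b1/w2/b2 are mine): only the MLP input and the weights are saved for backward, ffn1_out and gelu_out are recomputed during the backward pass, and FFN2 is never re-executed because a linear layer's backward only needs its input and weight.

```python
import torch
import torch.nn.functional as F


class SelectiveRecomputeMLP(torch.autograd.Function):
    """fc1 -> GELU -> fc2 with selective recompute of fc1/GELU in backward."""

    @staticmethod
    def forward(ctx, ln_out, w1, b1, w2, b2):
        ffn1_out = F.linear(ln_out, w1, b1)
        gelu_out = F.gelu(ffn1_out)
        ffn2_out = F.linear(gelu_out, w2, b2)
        # Deliberately do NOT save the large intermediates (ffn1_out, gelu_out);
        # only the block input and the parameters are kept.
        ctx.save_for_backward(ln_out, w1, b1, w2, b2)
        return ffn2_out

    @staticmethod
    def backward(ctx, grad_out):
        ln_out, w1, b1, w2, b2 = ctx.saved_tensors
        # Recompute FFN1 + GELU under a fresh autograd graph
        # (assumes w1 and b1 require grad).
        with torch.enable_grad():
            ln_out_ = ln_out.detach().requires_grad_(True)
            ffn1_out = F.linear(ln_out_, w1, b1)
            gelu_out = F.gelu(ffn1_out)
        # FFN2 gradients come directly from the recomputed gelu_out;
        # the fc2 matmul itself is never re-run.
        grad_gelu = grad_out @ w2
        g = grad_out.reshape(-1, grad_out.shape[-1])
        a = gelu_out.reshape(-1, gelu_out.shape[-1])
        grad_w2 = g.t() @ a
        grad_b2 = g.sum(0)
        # Backprop through GELU and FFN1 on the recomputed subgraph.
        grad_ln, grad_w1, grad_b1 = torch.autograd.grad(
            gelu_out, (ln_out_, w1, b1), grad_gelu)
        return grad_ln, grad_w1, grad_b1, grad_w2, grad_b2
```

Usage would look something like `out = SelectiveRecomputeMLP.apply(ln_out, fc1.weight, fc1.bias, fc2.weight, fc2.bias)`; the point is that only the hidden-size ln_out is stashed instead of the two ffn-hidden-size intermediates.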

Thank you!
