Add ability to enable/disable act ckpt and seq parallelism in GPT (NVIDIA#6327)

* Add ability to enable/disable act ckpt and seq parallelism

Signed-off-by: Markel Sanz Ausin <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove num_micro_batches_with_partial_activation_checkpoints

Signed-off-by: Markel Sanz Ausin <[email protected]>

* Added property to self.model and added restore/reset config values.

Signed-off-by: Markel Sanz Ausin <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use self.model property

Signed-off-by: Markel Sanz Ausin <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed original_act_ckpt

Signed-off-by: Markel Sanz Ausin <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add docstrings to reset/restore act ckpt

Signed-off-by: Markel Sanz Ausin <[email protected]>

* Property removed from self.model and replaced with get_gpt_module_list function.

Signed-off-by: Markel Sanz Ausin <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Markel Sanz Ausin <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
3 people authored and hsiehjackson committed Jun 2, 2023
1 parent 40f3628 commit e5397a3
Showing 1 changed file with 80 additions and 0 deletions.
@@ -170,6 +170,14 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

        self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)

    def get_gpt_module_list(self):
        if isinstance(self.model, list):
            return [model.module if isinstance(model, Float16Module) else model for model in self.model]
        elif isinstance(self.model, Float16Module):
            return [self.model.module]
        else:
            return [self.model]

    def set_inference_config(self, inference_config):
        self._inference_config = inference_config

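For illustration, here is a minimal, self-contained sketch (not part of this commit; Float16Module below is a stand-in for Megatron's wrapper class) of how get_gpt_module_list normalizes the three possible shapes of self.model into a flat list of module chunks:

class Float16Module:
    """Stand-in for Megatron's Float16Module; wraps the real module."""

    def __init__(self, module):
        self.module = module


def get_gpt_module_list(model):
    # Interleaved/virtual pipeline parallelism stores a list of model chunks.
    if isinstance(model, list):
        return [m.module if isinstance(m, Float16Module) else m for m in model]
    # A single fp16/bf16-wrapped module.
    elif isinstance(model, Float16Module):
        return [model.module]
    # A single bare module.
    return [model]


gpt = object()
assert get_gpt_module_list(gpt) == [gpt]
assert get_gpt_module_list(Float16Module(gpt)) == [gpt]
assert get_gpt_module_list([Float16Module(gpt), gpt]) == [gpt, gpt]

Callers can then iterate uniformly over module chunks regardless of precision wrapping or virtual pipeline configuration.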
@@ -1002,3 +1010,75 @@ def parameters(self):
            return itertools.chain.from_iterable(module.parameters() for module in self.model)
        else:
            return self.model.parameters()

    def _reset_activation_checkpointing_args(self):
        """ Disables activation checkpointing completely and saves the values so that
            _restore_activation_checkpointing_args can restore them later. This function must always be
            called before _restore_activation_checkpointing_args.
        """
        # Store values to restore them later.
        self.last_activations_checkpoint_granularity = self.cfg.activations_checkpoint_granularity
        self.last_activations_checkpoint_method = self.cfg.activations_checkpoint_method
        self.last_activations_checkpoint_num_layers = self.cfg.activations_checkpoint_num_layers
        self.last_activations_checkpoint_layers_per_pipeline = self.cfg.activations_checkpoint_layers_per_pipeline

        # Reset config values. Needed for calling generate.
        self.cfg.activations_checkpoint_granularity = None
        self.cfg.activations_checkpoint_method = None
        self.cfg.activations_checkpoint_num_layers = None
        self.cfg.activations_checkpoint_layers_per_pipeline = None

        # Reset model parameters.
        for module in self.get_gpt_module_list():
            module.language_model.encoder.activations_checkpoint_granularity = None
            module.language_model.encoder.activations_checkpoint_method = None
            module.language_model.encoder.activations_checkpoint_num_layers = None
            module.language_model.encoder.activations_checkpoint_layers_per_pipeline = None

    def _restore_activation_checkpointing_args(self):
        """ Restores the activation checkpointing parameters using the values saved by
            _reset_activation_checkpointing_args. This function must never be called before
            _reset_activation_checkpointing_args.
        """
        # Restore config values.
        self.cfg.activations_checkpoint_granularity = self.last_activations_checkpoint_granularity
        self.cfg.activations_checkpoint_method = self.last_activations_checkpoint_method
        self.cfg.activations_checkpoint_num_layers = self.last_activations_checkpoint_num_layers
        self.cfg.activations_checkpoint_layers_per_pipeline = self.last_activations_checkpoint_layers_per_pipeline

        # Restore model parameters.
        for module in self.get_gpt_module_list():
            module.language_model.encoder.activations_checkpoint_granularity = self.last_activations_checkpoint_granularity
            module.language_model.encoder.activations_checkpoint_method = self.last_activations_checkpoint_method
            module.language_model.encoder.activations_checkpoint_num_layers = self.last_activations_checkpoint_num_layers
            module.language_model.encoder.activations_checkpoint_layers_per_pipeline = (
                self.last_activations_checkpoint_layers_per_pipeline
            )

    def _reset_sequence_parallelism_args(self):
        """ Disables sequence parallelism completely and saves the values so that
            _restore_sequence_parallelism_args can restore them later. This function must always be
            called before _restore_sequence_parallelism_args.
        """
        # Store values to restore them later.
        self.last_sequence_parallel = self.cfg.sequence_parallel

        # Reset config values. Needed for calling generate.
        self.cfg.sequence_parallel = None

        # Reset model parameters.
        for module in self.get_gpt_module_list():
            module.language_model.encoder.sequence_parallel = None

    def _restore_sequence_parallelism_args(self):
        """ Restores the sequence parallelism parameters using the values saved by
            _reset_sequence_parallelism_args. This function must never be called before
            _reset_sequence_parallelism_args.
        """
        # Restore config values.
        self.cfg.sequence_parallel = self.last_sequence_parallel

        # Restore model parameters.
        for module in self.get_gpt_module_list():
            module.language_model.encoder.sequence_parallel = self.last_sequence_parallel
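As a sketch of the intended calling pattern (the wrapper below is hypothetical, not part of this commit), the reset/restore pairs are meant to bracket inference-style calls such as generate, with each restore running only after its matching reset:

def generate_with_optimizations_disabled(model, *args, **kwargs):
    # Disable activation checkpointing and sequence parallelism for inference.
    model._reset_activation_checkpointing_args()
    model._reset_sequence_parallelism_args()
    try:
        return model.generate(*args, **kwargs)
    finally:
        # Restore the saved training-time configuration, even if generate raises.
        model._restore_activation_checkpointing_args()
        model._restore_sequence_parallelism_args()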
