Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag to get attention from fusion #6049

Merged
merged 1 commit into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ model:
bias_activation_fusion: True # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
get_attention_mask_from_fusion: False # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.


# Miscellaneous
Expand Down
42 changes: 31 additions & 11 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps

self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', False)

def set_inference_config(self, inference_config):
self._inference_config = inference_config

Expand Down Expand Up @@ -507,25 +509,43 @@ def allreduce_first_last_embeddings(self):

def get_forward_output_and_loss_func(self, validation_step=False):
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
# GPT3 uses only causal mask, which doesn't need attention mask
batch = next(dataloader_iter)
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
batch = next(dataloader_iter)
for k in batch.keys():
batch[k] = batch[k].cuda(non_blocking=True) if k not in ['attention_mask'] else None
if self.get_attention_mask_from_fusion:
batch[k] = batch[k].cuda(non_blocking=True) if k not in ['attention_mask'] else None
else:
batch[k] = batch[k].cuda(non_blocking=True)
else:
if parallel_state.is_pipeline_first_stage():
batch = next(dataloader_iter)
# First pipeline stage needs only the tokens and position_ids
# First pipeline stage needs tokens, position_ids, and attention_mask
for k in batch.keys():
batch[k] = batch[k].cuda(non_blocking=True) if k in ['tokens', 'position_ids'] else None
if self.get_attention_mask_from_fusion:
batch[k] = batch[k].cuda(non_blocking=True) if k in ['tokens', 'position_ids'] else None
else:
batch[k] = (
batch[k].cuda(non_blocking=True)
if k in ['tokens', 'position_ids', 'attention_mask']
else None
)
elif parallel_state.is_pipeline_last_stage():
batch = next(dataloader_iter)
# Last pipeline stage needs only the labels and loss_mask
# Last pipeline stage needs the labels, loss_mask, and attention_mask
for k in batch.keys():
batch[k] = batch[k].cuda(non_blocking=True) if k in ['labels', 'loss_mask'] else None
if self.get_attention_mask_from_fusion:
batch[k] = batch[k].cuda(non_blocking=True) if k in ['labels', 'loss_mask'] else None
else:
batch[k] = (
batch[k].cuda(non_blocking=True)
if k in ['labels', 'loss_mask', 'attention_mask']
else None
)
else:
# Intermediate pipeline stage doesn't need any inputs
batch = {k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels']}
# Intermediate pipeline stage only needs attention_mask
if self.get_attention_mask_from_fusion:
batch = {k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels']}
else:
for k in batch.keys():
batch[k] = batch[k].cuda(non_blocking=True) if k in ['attention_mask'] else None

output_tensor = model(
batch['tokens'],
Expand Down