Skip to content

Commit

Permalink
Pre-compute max_seqlen and cu_seqlens_argmin in all model-parallel cases (NVIDIA#8222)
Browse files Browse the repository at this point in the history

Signed-off-by: Sangkug Lym <[email protected]>
  • Loading branch information
erhoo82 committed Jan 23, 2024
1 parent 8abdb25 commit 40fb2ce
Showing 1 changed file with 2 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -882,17 +882,14 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_

# Transfer needed data to GPU
required_keys = set()
max_seqlen, cu_seqlens_argmin = None, None
max_seqlen = batch.pop('max_seqlen').squeeze() if 'max_seqlen' in batch else None
cu_seqlens_argmin = batch.pop('cu_seqlens_argmin') if 'cu_seqlens_argmin' in batch else None
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
required_keys.update(batch.keys())
else:
required_keys.add('attention_mask')
if 'cu_seqlens' in batch:
required_keys.add('cu_seqlens')
if 'max_seqlen' in batch:
max_seqlen = batch['max_seqlen'].squeeze()
if 'cu_seqlens_argmin' in batch:
cu_seqlens_argmin = batch['cu_seqlens_argmin']
if parallel_state.is_pipeline_first_stage():
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
Expand Down

0 comments on commit 40fb2ce

Please sign in to comment.