@@ -422,11 +422,11 @@ def forward_backward_step(
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
-        # extra_args are.
-        extra_args = {}
+        # extra_kwargs are.
+        extra_kwargs = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
@@ -457,17 +457,15 @@ def forward_backward_step(
                     self.pp_schedule.step(
                         inputs,
                         **extra_inputs,
-                        **extra_args,
+                        **extra_kwargs,
                         target=targets,
                         losses=losses,
-                        input_batch=inputs,
                     )
                 else:
                     self.pp_schedule.step(
-                        **extra_args,
+                        **extra_kwargs,
                         target=targets,
                         losses=losses,
-                        input_batch=inputs,
                     )
 
             # accumulate losses across pipeline microbatches
@@ -485,7 +483,7 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_inputs, **extra_args)
+                    pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
                     loss = self.loss_fn(pred, labels)
                     # need to free pred before bwd to avoid peaking memory
                     del pred
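
For context, a minimal standalone sketch of the split this rename is about, per the comment in the first hunk: `extra_inputs` are extra input tensors consumed only by the first pipeline stage, while `extra_kwargs` (e.g. attention masks) are keyword arguments forwarded to every stage. `FakeSchedule`, the stage flags, and the batch contents below are hypothetical stand-ins, not the actual torchtitan trainer or schedule.

```python
# Minimal sketch, assuming a step()-style schedule API like the one in the diff.
# FakeSchedule and the stage flags are hypothetical stand-ins, not torchtitan code.
import torch


class FakeSchedule:
    def step(self, *inputs, target=None, losses=None, **kwargs):
        # A real PP schedule would run microbatched forward/backward here;
        # keyword arguments (e.g. attention_masks) reach every stage, while
        # positional inputs are only fed to the first stage.
        print(f"positional inputs: {len(inputs)}, kwargs: {sorted(kwargs)}")


def forward_backward_step(schedule, input_dict, labels, has_first_stage, has_last_stage):
    inputs = input_dict["input"]
    # Extra input tensors: consumed by the first stage only.
    extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
    # Extra keyword arguments: forwarded to all pipeline stages.
    extra_kwargs = {"attention_masks": None}  # placeholder mask object

    targets, losses = (labels, []) if has_last_stage else (None, None)
    if has_first_stage:
        schedule.step(inputs, **extra_inputs, **extra_kwargs, target=targets, losses=losses)
    else:
        schedule.step(**extra_kwargs, target=targets, losses=losses)
    return losses


batch = {"input": torch.zeros(2, 8, dtype=torch.long), "eos_id": torch.tensor(0)}
forward_backward_step(
    FakeSchedule(), batch, labels=torch.zeros(2, 8), has_first_stage=True, has_last_stage=True
)
```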