diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py
index bedd1ce6e8..1b54193510 100644
--- a/torchtitan/components/validate.py
+++ b/torchtitan/components/validate.py
@@ -137,12 +137,9 @@ def validate(
                             inputs,
                             target=targets,
                             losses=losses,
-                            input_batch=inputs,
                         )
                     else:
-                        self.pp_schedule.eval(
-                            target=targets, losses=losses, input_batch=inputs
-                        )
+                        self.pp_schedule.eval(target=targets, losses=losses)
 
                 # accumulate losses across pipeline microbatches
                 # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py
index 6dbe7d78f3..4f16035b6f 100644
--- a/torchtitan/experiments/forge/example_train.py
+++ b/torchtitan/experiments/forge/example_train.py
@@ -158,10 +158,10 @@ def forward_backward_step(
         parallel_dims = self.parallel_dims
 
         inputs = input_dict["input"]
-        extra_args = {}
+        extra_kwargs = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
             )
@@ -187,17 +187,15 @@ def forward_backward_step(
             if self.pp_has_first_stage:
                 self.pp_schedule.step(
                     inputs,
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )
             else:
                 self.pp_schedule.step(
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )
 
             # accumulate losses across pipeline microbatches
@@ -215,7 +213,7 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_args)
+                    pred = model_parts[0](inputs, **extra_kwargs)
                     loss = self.loss_fn(pred, labels)
                     # need to free to before bwd to avoid peaking memory
                     del pred
diff --git a/torchtitan/experiments/vlm/model/model.py b/torchtitan/experiments/vlm/model/model.py
index 712cd8058b..0f868cb93c 100644
--- a/torchtitan/experiments/vlm/model/model.py
+++ b/torchtitan/experiments/vlm/model/model.py
@@ -96,7 +96,6 @@ def forward(
         grid_thw: torch.Tensor,
         special_tokens: SpecialTokens,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index d5bc9b1016..3cf56eb1b2 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -411,7 +411,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Forward pass for the Transformer model.
@@ -421,10 +420,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).
 
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index 6f10719d12..124153f14c 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -478,7 +478,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -488,10 +487,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).
 
         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py
index 93ff4e89b0..c8241b84de 100644
--- a/torchtitan/models/llama4/model/model.py
+++ b/torchtitan/models/llama4/model/model.py
@@ -539,7 +539,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -549,10 +548,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).
 
         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py
index 0fff490bf3..32c98342bb 100644
--- a/torchtitan/models/qwen3/model/model.py
+++ b/torchtitan/models/qwen3/model/model.py
@@ -470,7 +470,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -480,10 +479,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).
 
         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
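Annotation (not part of the patch): together with the DeepSeek-V3 and VLM hunks above, these changes leave every model with the same trimmed interface. The raw dataloader batch is no longer threaded through forward(); callers pass tokens plus, when flex attention is enabled, a precomputed attention_masks. A minimal caller-side sketch of the post-patch convention for the text-only models, where batch, masks, tokenizer, and model are illustrative names rather than identifiers from the patch:

    # Sketch of the post-patch calling convention; assumes use_flex_attn is enabled.
    # "batch", "masks", "tokenizer", and "model" are illustrative names only.
    masks = model.get_attention_masks(input_batch=batch, tokenizer=tokenizer)

    # Before this patch: model(batch, attention_masks=masks, input_batch=batch)
    # After this patch:  input_batch is no longer accepted by forward().
    logits = model(batch, attention_masks=masks)

Judging from the removed docstrings and the comment in torchtitan/train.py below, the masks are built once from the real batch and forwarded as keyword arguments, so later pipeline stages no longer need to see the dataloader output at all.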
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 1d5e0e500a..8fd2f8a86e 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -422,11 +422,11 @@ def forward_backward_step(
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
-        # extra_args are.
-        extra_args = {}
+        # extra_kwargs are.
+        extra_kwargs = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
@@ -457,17 +457,15 @@ def forward_backward_step(
                 self.pp_schedule.step(
                     inputs,
                     **extra_inputs,
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )
             else:
                 self.pp_schedule.step(
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )
 
             # accumulate losses across pipeline microbatches
@@ -485,7 +483,7 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_inputs, **extra_args)
+                    pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
                     loss = self.loss_fn(pred, labels)
                     # need to free pred before bwd to avoid peaking memory
                     del pred
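Annotation (not part of the patch): condensed, the post-patch forward_backward_step in torchtitan/train.py reduces to the control flow sketched below. Per the in-code comment, extra_inputs are consumed only by the first pipeline stage while extra_kwargs are forwarded to every stage, which is why the raw input_batch no longer has to ride along. Context managers, microbatch loss accumulation, and the backward pass are omitted; this is a reading aid, not a drop-in replacement.

    # Condensed sketch of the post-patch flow (reading aid only).
    extra_kwargs = {}
    if getattr(self.model_args, "use_flex_attn", False):
        # Masks are built once from the raw batch; from here on they travel as kwargs.
        extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
            input_batch=inputs,
            tokenizer=self.tokenizer,
            extra_inputs=extra_inputs,
        )

    if parallel_dims.pp_enabled:
        if self.pp_has_first_stage:
            # First stage: real inputs, per-stage extra_inputs, forwarded extra_kwargs.
            self.pp_schedule.step(
                inputs, **extra_inputs, **extra_kwargs, target=targets, losses=losses
            )
        else:
            # Later stages: only the forwarded kwargs; no raw batch required.
            self.pp_schedule.step(**extra_kwargs, target=targets, losses=losses)
    else:
        # Single-stage path: everything goes straight into the model.
        pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
        loss = self.loss_fn(pred, labels)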