5 changes: 1 addition & 4 deletions torchtitan/components/validate.py
@@ -137,12 +137,9 @@ def validate(
                             inputs,
                             target=targets,
                             losses=losses,
-                            input_batch=inputs,
                         )
                     else:
-                        self.pp_schedule.eval(
-                            target=targets, losses=losses, input_batch=inputs
-                        )
+                        self.pp_schedule.eval(target=targets, losses=losses)

                 # accumulate losses across pipeline microbatches
                 # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
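For reference, a minimal sketch of the updated eval call pattern. The stub schedule and helper below are assumptions invented for this example (they are not the torchtitan or torch.distributed.pipelining API); only the keyword shapes mirror the diff above.

```python
class _StubSchedule:
    """Stand-in for a pipeline eval schedule (illustrative only)."""

    def eval(self, *inputs, target=None, losses=None):
        # a real schedule would run eval microbatches here; note that no
        # input_batch kwarg is threaded through anymore
        losses.append(0.0)


def run_validation_eval(pp_schedule, inputs, targets, losses, has_first_stage):
    if has_first_stage:
        # first-stage ranks still feed the input tokens positionally
        pp_schedule.eval(inputs, target=targets, losses=losses)
    else:
        # later stages receive activations from the previous stage, so only
        # target and losses are passed
        pp_schedule.eval(target=targets, losses=losses)


losses = []
run_validation_eval(_StubSchedule(), [1, 2, 3], [1, 2, 3], losses, has_first_stage=False)
print(losses)  # [0.0]
```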
12 changes: 5 additions & 7 deletions torchtitan/experiments/forge/example_train.py
@@ -158,10 +158,10 @@ def forward_backward_step(
         parallel_dims = self.parallel_dims

         inputs = input_dict["input"]
-        extra_args = {}
+        extra_kwargs = {}

         if getattr(self.model_args, "use_flex_attn", False):
-            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
             )
@@ -187,17 +187,15 @@ def forward_backward_step(
             if self.pp_has_first_stage:
                 self.pp_schedule.step(
                     inputs,
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )
             else:
                 self.pp_schedule.step(
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )

             # accumulate losses across pipeline microbatches
@@ -215,7 +213,7 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_args)
+                    pred = model_parts[0](inputs, **extra_kwargs)
                     loss = self.loss_fn(pred, labels)
                     # need to free to before bwd to avoid peaking memory
                     del pred
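The same rename flows through the single-stage (non-PP) path: masks are built once from the local batch and handed to the model as a keyword argument instead of passing the raw batch into forward. Below is a self-contained sketch with a toy model; the toy model and its get_attention_masks are assumptions for illustration, and only the calling convention mirrors the diff.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in model; not the torchtitan implementation."""

    def __init__(self, vocab: int = 128, dim: int = 16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, vocab)

    def get_attention_masks(self, input_batch: torch.Tensor, tokenizer=None):
        # stand-in for the real mask builder; here just a padding mask
        return input_batch != 0

    def forward(self, tokens: torch.Tensor, attention_masks=None):
        return self.out(self.emb(tokens))


model = ToyModel()
inputs = torch.randint(1, 128, (2, 8))

extra_kwargs = {}
if True:  # stands in for the use_flex_attn check
    extra_kwargs["attention_masks"] = model.get_attention_masks(input_batch=inputs)

# the model no longer receives input_batch; masks travel via extra_kwargs
pred = model(inputs, **extra_kwargs)
print(pred.shape)  # torch.Size([2, 8, 128])
```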
1 change: 0 additions & 1 deletion torchtitan/experiments/vlm/model/model.py
@@ -96,7 +96,6 @@ def forward(
         grid_thw: torch.Tensor,
         special_tokens: SpecialTokens,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
5 changes: 0 additions & 5 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -411,7 +411,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Forward pass for the Transformer model.
@@ -421,10 +420,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).

         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
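On the model side, the change reduces forward to tokens plus an optional prebuilt mask; the llama3, llama4, and qwen3 files below receive the identical edit. A tiny stand-in transformer sketches the resulting signature (TinyTransformer and the AttentionMasksType alias are assumptions made to keep the example self-contained; this is not the deepseek_v3 implementation).

```python
from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn

AttentionMasksType = Any  # local alias so the sketch is self-contained


class TinyTransformer(nn.Module):
    """Illustrative stand-in for a pipeline-splittable transformer."""

    def __init__(self, vocab: int = 128, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.output = nn.Linear(dim, vocab)

    def forward(
        self,
        tokens: torch.Tensor,
        attention_masks: AttentionMasksType | None = None,
        # input_batch is gone: non-first PP stages no longer see the raw
        # dataloader batch, and document masks arrive prebuilt via attention_masks
    ) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        return self.output(h)
```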
5 changes: 0 additions & 5 deletions torchtitan/models/llama3/model/model.py
@@ -478,7 +478,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -488,10 +487,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).

         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
5 changes: 0 additions & 5 deletions torchtitan/models/llama4/model/model.py
@@ -539,7 +539,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -549,10 +548,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).

         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
5 changes: 0 additions & 5 deletions torchtitan/models/qwen3/model/model.py
@@ -470,7 +470,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        input_batch: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -480,10 +479,6 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
-            input_batch (torch.Tensor): The input batch read from the dataloader.
-                This will always be the input batch regardless of the pipeline stage.
-                This field is required for non-first PP stages to perform document
-                masking attention (to analyze the boundary of the document).

         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
14 changes: 6 additions & 8 deletions torchtitan/train.py
@@ -422,11 +422,11 @@ def forward_backward_step(
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
-        # extra_args are.
-        extra_args = {}
+        # extra_kwargs are.
+        extra_kwargs = {}

         if getattr(self.model_args, "use_flex_attn", False):
-            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
@@ -457,17 +457,15 @@
                 self.pp_schedule.step(
                     inputs,
                     **extra_inputs,
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )
             else:
                 self.pp_schedule.step(
-                    **extra_args,
+                    **extra_kwargs,
                     target=targets,
                     losses=losses,
-                    input_batch=inputs,
                 )

             # accumulate losses across pipeline microbatches
@@ -485,7 +483,7 @@
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_inputs, **extra_args)
+                    pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
                     loss = self.loss_fn(pred, labels)
                     # need to free pred before bwd to avoid peaking memory
                     del pred
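A toy illustration of the rewritten comment above: positional extra inputs are consumed only by the first stage, while keyword arguments in extra_kwargs are replicated to every stage's forward, which is why attention_masks lives there. The two-stage stub below is an assumption for illustration, not the real pipeline schedule.

```python
class _StubTwoStageSchedule:
    """Toy stand-in for a two-stage pipeline schedule (illustrative only)."""

    def __init__(self, stage0, stage1):
        self.stage0, self.stage1 = stage0, stage1

    def step(self, *inputs, target=None, losses=None, **kwargs):
        # the first stage sees the raw inputs plus the shared kwargs
        acts = self.stage0(*inputs, **kwargs)
        # later stages see only activations from the previous stage, yet still
        # receive the shared kwargs -- the property attention_masks relies on
        out = self.stage1(acts, **kwargs)
        losses.append(float(out != target))


def stage0(tokens, attention_masks=None):
    return [t + 1 for t in tokens]


def stage1(acts, attention_masks=None):
    return sum(acts)


losses = []
sched = _StubTwoStageSchedule(stage0, stage1)
sched.step([1, 2, 3], attention_masks="doc-mask", target=9, losses=losses)
print(losses)  # [0.0]
```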