18 changes: 18 additions & 0 deletions torchtitan/components/loss.py
@@ -78,3 +78,21 @@ def build_mse_loss(job_config: JobConfig, **kwargs):
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
return loss_fn


def moe_loss(
pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
labels: torch.Tensor,
loss_fn: LossFunction,
Contributor:

I think we could have a consistent API with the other loss functions: take job_config as input and plug the loss in like the other loss functions in TrainSpec:

build_loss_fn=build_cross_entropy_loss,

That way we could avoid the change in train.py. WDYT?

Contributor:

I agree. I think we can use a new build_loss_fn for models that may have MoE, or we can update build_cross_entropy_loss to check whether MoE is enabled from the config here:

if job_config.compile.enable and "loss" in job_config.compile.components:

Contributor Author:

You mean something like build_multiple_loss? Or do we do build_ce_and_moe_loss and build_mse_and_moe_loss?

) -> torch.Tensor:
"""Sequence-wise auxiliary load balance loss function for MoE
model training.
Comment on lines +88 to +89
Contributor:

We do not need to be specific about sequence-wise or batch-wise here, right?

"""
if isinstance(pred, tuple):
pred, load_balance_loss = pred
loss = loss_fn(pred, labels)
# USE STE to make the magnitude of loss remain the same
Contributor:

Maybe we can be more explicit here.

Suggested change
# USE STE to make the magnitude of loss remain the same
# Add auxiliary loss to the computation graph for gradients in the backward pass,
# but cancel out its numeric value so the forward pass only logs language model task loss.

loss = loss + (load_balance_loss - load_balance_loss.detach())
else:
loss = loss_fn(pred, labels)
return loss
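Side note (not part of the diff): a minimal sketch of why the detach trick works. The forward value stays equal to the task loss, while the auxiliary loss still contributes gradients in the backward pass:

import torch

task_loss = torch.tensor(2.0, requires_grad=True)
aux_loss = torch.tensor(0.5, requires_grad=True)

# Forward value is 2.0 (the aux term cancels numerically), but aux_loss keeps its gradient path.
total = task_loss + (aux_loss - aux_loss.detach())
print(total.item())  # 2.0
total.backward()
print(task_loss.grad.item(), aux_loss.grad.item())  # 1.0 1.0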
15 changes: 15 additions & 0 deletions torchtitan/config/job_config.py
@@ -97,6 +97,18 @@ class Metrics:
"""Whether to log metrics to Weights & Biases"""


@dataclass
class ExtraLosses:
Contributor:

This section is specifically for the MoE load-balancing loss for now. Do you foresee any other loss-related params being used in this section? If not, let's make the name more descriptive and specific.

Contributor:

Follow-up here: should we merge these configs into the Model dataclass?

load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
"""Type of load balance loss to use"""

load_balance_loss_weight: float = 0
"""Weight of load balance loss"""

load_balance_coeff: float | None = 1e-3
Contributor:

Probably rename this to loss_free_load_balance_coeff? And IIUC, because it's loss-free, we need to set it to None if we use loss-based load balancing; otherwise it will register an optimizer hook here:

if _should_register_moe_balancing_hook(model_parts):

Contributor:

I think both loss-free and loss-based load balancing are used simultaneously in DeepSeek-V3.

Contributor Author:

Yes, DeepSeek-V3 (and GLM-4.5, as far as I know) uses both.

load_balance_coeff is the name used in the current repo, and yes, maybe we should rename them properly.

Contributor:

Suggested change
load_balance_coeff: float | None = 1e-3
load_balance_bias_coeff: float | None = 1e-3

"""Coefficient of bias update for aux-loss-free load balancing"""


@dataclass
class Model:
name: str = "llama3"
@@ -130,6 +142,9 @@ class Model:
converters have been applied.
"""

extra_losses: ExtraLosses = field(default_factory=ExtraLosses)
"""Extra losses to use"""


@dataclass
class Optimizer:
5 changes: 5 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -95,6 +95,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

losses_config = job_config.model.extra_losses
self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
self.moe_args.load_balance_coeff = losses_config.load_balance_coeff
Contributor:

Suggested change
self.moe_args.load_balance_coeff = losses_config.load_balance_coeff
self.moe_args.load_balance_bias_coeff = losses_config.load_balance_bias_coeff


if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
25 changes: 20 additions & 5 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -309,6 +309,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
accumulated_load_balance_loss: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
@@ -323,10 +324,15 @@
"""
x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x))
accumulated_load_balance_loss = (
accumulated_load_balance_loss + load_balance_loss
)
else:
x = x + self.feed_forward(self.ffn_norm(x))
return x
ffn_moe_output = self.feed_forward(self.ffn_norm(x))

x = x + ffn_moe_output
return x, accumulated_load_balance_loss

def init_weights(self, buffer_device: torch.device):
for norm in (self.attention_norm, self.ffn_norm):
@@ -410,6 +416,7 @@ def get_attention_masks(
def forward(
self,
tokens: torch.Tensor,
accumulated_load_balance_loss: torch.Tensor | None = None,
attention_masks: AttentionMasksType | None = None,
):
"""
@@ -427,8 +434,16 @@

h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

accumulated_load_balance_loss = (
torch.zeros((), device=h.device, dtype=torch.float32)
if accumulated_load_balance_loss is None
else accumulated_load_balance_loss
)

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks)
h, accumulated_load_balance_loss = layer(
h, self.freqs_cis, accumulated_load_balance_loss, attention_masks
)
h = self.norm(h) if self.norm is not None else h
output = self.output(h) if self.output is not None else h
return output
return output, accumulated_load_balance_loss
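Illustrative call-site sketch (not part of the diff): with this change the model returns a tuple, which is the shape moe_loss in loss.py expects; model, tokens, labels, and base_loss_fn below are placeholders:

pred = model(tokens)   # (logits, accumulated_load_balance_loss)
loss = moe_loss(pred, labels, loss_fn=base_loss_fn)
loss.backward()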
124 changes: 121 additions & 3 deletions torchtitan/models/moe/moe.py
@@ -30,7 +30,8 @@ class MoEArgs:
top_k: int = 1
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3
Contributor:

We can rename this config (here and in several other places) to differentiate loss-free load balancing based on adaptive bias updates from auxiliary-loss-based load balancing.

Suggested change
load_balance_coeff: float | None = 1e-3
load_balance_bias_coeff: float | None = 1e-3


load_balance_loss_weight: float = 0
load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
_debug_force_load_balance: bool = False
# if True, we force each experts get same amount of token via round-robin

@@ -287,7 +288,7 @@ def forward(
max=self.num_experts,
)

return top_scores, selected_experts_indices, num_tokens_per_expert
return top_scores, scores, selected_experts_indices, num_tokens_per_expert

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
@@ -359,6 +360,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
super().__init__()

num_experts = moe_args.num_experts
self.topk = moe_args.top_k
Contributor:

Nit:

Suggested change
self.topk = moe_args.top_k
self.top_k = moe_args.top_k

self.num_experts = num_experts
Contributor:
Suggested change
self.num_experts = num_experts

self.experts = GroupedExperts(
dim=dim,
hidden_dim=hidden_dim,
@@ -386,6 +389,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
# NOTE: tokens_per_expert is accumulated in the model forward pass.
# expert_bias is updated outside the model in an optimizer step pre hook
# to work with gradient accumulation.
self.load_balance_loss_weight = moe_args.load_balance_loss_weight
self.load_balance_loss_type = moe_args.load_balance_loss_type
self.load_balance_coeff = moe_args.load_balance_coeff
if self.load_balance_coeff is not None:
assert self.load_balance_coeff > 0.0
@@ -418,6 +423,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# num_tokens_per_expert shape (num_experts,)
(
top_scores,
scores,
selected_experts_indices,
num_tokens_per_expert,
) = self.router(x, self.expert_bias)
@@ -430,6 +436,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)

if self.training:
if self.load_balance_loss_type == "sequence_wise":
load_balance_loss = MoE.sequence_wise_aux_loss(
scores,
selected_experts_indices.long(),
bs,
slen,
self.topk,
Contributor:
Suggested change
self.topk,
self.top_k,

self.load_balance_loss_weight,
)
elif self.load_balance_loss_type == "batch_wise":
load_balance_loss = MoE.batch_wise_aux_loss(
scores,
num_tokens_per_expert,
self.topk,
Contributor:
Suggested change
self.topk,
self.top_k,

self.load_balance_loss_weight,
)
else:
load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)

# top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
# NOTE: the reason we need to compute num_tokens_per_expert again is:
@@ -479,7 +505,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dim=0, index=token_indices_experts_sorted, src=routed_output
)
out = out.reshape(bs, slen, dim)
return out

Contributor:
Suggested change

return out, load_balance_loss

def init_weights(
self,
@@ -499,3 +526,94 @@ def init_weights(
self.expert_bias = torch.zeros(
self.experts.num_experts, dtype=torch.float32
)

@staticmethod
@torch.compile(fullgraph=True)
Contributor:

Noob question: do we always want to compile this loss? Is it for speed? Should we provide an option for users to control whether to compile it, like the check if job_config.compile.enable and "loss" in job_config.compile.components in loss.py?

Contributor Author:

Yep, for speedup. I don't know whether it will be compiled automatically when we already use compile with fullgraph (I would expect so).

def sequence_wise_aux_loss(
scores: torch.Tensor, # Shape: (B*S, N) - Raw Sigmoid Affinities (s_{i,t})
indices: torch.Tensor, # Shape: (B*S, K) - Selected Expert Indices
B: int, # Batch size
S: int, # Sequence length (T in the paper)
top_k: int, # K_r
Contributor:

The K_r here is the same as K elsewhere in this function, right? Maybe we can use the consistent notation top_k in all comments and tell people this is K_r in the DeepSeek paper. Similarly, we can use N to denote the number of routed experts and tell people this is N_r in the DeepSeek paper.

aux_loss_alpha: float, # Alpha
Contributor:

Nit: just trying to be consistent with wording.

Suggested change
aux_loss_alpha: float, # Alpha
aux_loss_weight: float, # Alpha

) -> torch.Tensor:
"""
Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20).

Args:
scores: The dense affinity scores (s_{i,t}) for routed experts.
Should be the output of Sigmoid, shape (B*S, N).
indices: The top-k selected expert indices. Shape (B*S, K).
"""
if aux_loss_alpha <= 0:
return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)

# N_r: Total number of routed experts
N = scores.size(-1)

# 1. Reshape inputs to handle each sequence separately: (B, S, N)
# This ensures we calculate P_i and f_i per sequence (Eq 20 & 18).
scores_per_seq = scores.view(B, S, N)
indices_per_seq = indices.view(B, S, top_k)
Contributor:

This is not used afterwards.

Suggested change
indices_per_seq = indices.view(B, S, top_k)


# 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t}
# DeepSeek-V3 uses Sigmoid, so scores don't sum to 1.
# Eq 19 explicitly requires dividing by the sum of all affinities.
# denominator shape: (B, S, 1)
denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20
probs_per_seq = scores_per_seq / denominator # This is s'_{i,t}

# 3. Eq 20: Calculate P_i (Average probability per expert for each sequence)
# P_i = (1/T) * sum_{t=1}^T (s'_{i,t})
# We average over the Sequence dimension (dim=1).
# P_i shape: (B, N)
P_i = probs_per_seq.mean(dim=1)

# 4. Eq 18: Calculate f_i (Fraction of tokens selecting expert i per sequence)
# f_i = (N / (K * T)) * count_i

# Flatten the top-k dimension to count hits per sequence: (B, S*K)
flat_indices_per_seq = indices_per_seq.view(B, -1)
Contributor:
Suggested change
flat_indices_per_seq = indices_per_seq.view(B, -1)
batch_indices_per_seq = indices.flatten(1)

selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
selection_counts.scatter_add_(1, flat_indices_per_seq, src)
Comment on lines +577 to +579
Contributor:

It seems to me we do not need to create a new src tensor here. We may consider using torch.bincount to save memory.

Suggested change
selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
selection_counts.scatter_add_(1, flat_indices_per_seq, src)
offset = (torch.arange(B, device=batch_indices_per_seq.device).unsqueeze(1) * N)
flat_indices = (batch_indices_per_seq + offset).reshape(-1)
selection_counts = torch.bincount(flat_indices, minlength=B * N).reshape(B, N)
selection_counts = selection_counts.to(dtype=scores.dtype)
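Quick side check (not part of the diff) that the scatter_add_ and bincount formulations count the same thing, on toy sizes:

import torch

B, S, K, N = 2, 4, 2, 5
idx = torch.randint(0, N, (B, S * K))  # stand-in for the flattened expert indices

counts_scatter = torch.zeros(B, N).scatter_add_(1, idx, torch.ones(B, S * K))

offset = torch.arange(B).unsqueeze(1) * N
counts_bincount = (
    torch.bincount((idx + offset).reshape(-1), minlength=B * N).reshape(B, N).float()
)

assert torch.equal(counts_scatter, counts_bincount)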


# Calculate f_i for each sequence, T (tokens in sequence) is S
f_i = selection_counts * (N / (top_k * S))

# 5. Eq 17: Calculate Balance Loss
loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha
Contributor:
Suggested change
loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha
loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_weight


return loss_per_seq.mean()
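For reference (not part of the diff), the function above implements DeepSeek-V3 Eqs. 17-20, with N routed experts, K = top_k selections per token, and T = S tokens per sequence; the per-sequence loss is then averaged over the batch:

\mathcal{L}_{\mathrm{Bal}} = \alpha \sum_{i=1}^{N} f_i P_i

f_i = \frac{N}{K T} \sum_{t=1}^{T} \mathbb{1}\left(\text{token } t \text{ selects expert } i\right)

s'_{i,t} = \frac{s_{i,t}}{\sum_{j=1}^{N} s_{j,t}}, \qquad P_i = \frac{1}{T} \sum_{t=1}^{T} s'_{i,t}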

@staticmethod
@torch.compile(fullgraph=True)
def batch_wise_aux_loss(
scores: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
top_k: int,
aux_loss_alpha: float,
Contributor:
Suggested change
aux_loss_alpha: float,
aux_loss_weight: float,

) -> torch.Tensor:
"""
Computes Batch-Wise Auxiliary Loss.
Args:
scores: Dense probabilities (BS, N).
num_tokens_per_expert: Token counts (N).
top_k: Number of experts selected per token.
aux_loss_alpha: Scaling factor for the loss.
"""
if aux_loss_alpha <= 0:
return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)

# Total number of routed experts (N)
N = scores.size(1)
# Total number of tokens (T = BS * S)
T = scores.size(0)

P_i = scores.mean(dim=0)

f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T))

loss = (f_i * P_i).sum() * aux_loss_alpha
Contributor:
Suggested change
loss = (f_i * P_i).sum() * aux_loss_alpha
loss = (f_i * P_i).sum() * aux_loss_weight


return loss
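A usage sketch (not part of the diff) exercising both static methods on toy tensors; the import path is inferred from the file layout above and is an assumption:

import torch
from torchtitan.models.moe.moe import MoE  # assumed import path

B, S, N, K = 2, 8, 4, 2
scores = torch.rand(B * S, N)             # stand-in for sigmoid affinities s_{i,t}
indices = scores.topk(K, dim=-1).indices  # top-K selected experts per token

seq_loss = MoE.sequence_wise_aux_loss(scores, indices, B, S, K, 1e-4)

num_tokens_per_expert = torch.bincount(indices.reshape(-1), minlength=N)
batch_loss = MoE.batch_wise_aux_loss(scores, num_tokens_per_expert, K, 1e-4)
print(seq_loss.item(), batch_loss.item())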
8 changes: 7 additions & 1 deletion torchtitan/train.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import importlib
import os
import time
@@ -18,7 +19,7 @@
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderExhaustedError
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.components.loss import moe_loss, rescale_accumulated_loss
from torchtitan.components.metrics import (
build_metrics_processor,
ensure_pp_loss_visible,
@@ -184,6 +185,11 @@ def __init__(self, job_config: JobConfig):
job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
Contributor Author:

We could add a condition here to decide whether or not to wrap the loss for MoE. For now, all models in torchtitan return a single output, so it's fine as is.

Contributor:

If we subsume this MoE loss wrapper into build_loss_fn, we can avoid adding the logic here.

moe_loss,
loss_fn=self.loss_fn,
)
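A sketch of the reviewers' suggestion (hypothetical helper, not in this PR): folding the wrapper into a loss builder would keep train.py unchanged and follow the existing build_cross_entropy_loss pattern; it would then be registered as build_loss_fn=build_ce_and_moe_loss in the model's TrainSpec:

def build_ce_and_moe_loss(job_config: JobConfig, **kwargs):
    # Reuse the existing builder, then wrap it so the aux loss is added when pred is a tuple.
    base_loss_fn = build_cross_entropy_loss(job_config, **kwargs)
    return functools.partial(moe_loss, loss_fn=base_loss_fn)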

# verify batch sizes
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
Expand Down
Loading