[moe] brings batch/sequence-wise load balance loss #2061

base: main
@@ -78,3 +78,21 @@ def build_mse_loss(job_config: JobConfig, **kwargs):
        logger.info("Compiling the loss function with torch.compile")
        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
    return loss_fn


def moe_loss(
    pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
    labels: torch.Tensor,
    loss_fn: LossFunction,
) -> torch.Tensor:
    """Sequence-wise auxiliary load balance loss function for MoE
    model training.
Comment on lines +88 to +89
Contributor: We do not need to be specific on sequence-wise or batch-wise here, right?
| """ | ||||||||
| if isinstance(pred, tuple): | ||||||||
| pred, load_balance_loss = pred | ||||||||
| loss = loss_fn(pred, labels) | ||||||||
| # USE STE to make the magnitude of loss remain the same | ||||||||
Contributor: Maybe we can be more explicit here. (Suggested change)
        loss = loss + (load_balance_loss - load_balance_loss.detach())
    else:
        loss = loss_fn(pred, labels)
    return loss
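As a side note on the detach trick above: the following is a minimal, self-contained sketch (the tensors and values are illustrative, not from the PR) of why the reported loss magnitude stays equal to the main loss while the auxiliary gradient still flows.

```python
# Illustrative sketch of:
#   loss = loss + (load_balance_loss - load_balance_loss.detach())
# The value of `total` equals the main loss, but the auxiliary term
# still contributes to the gradient.
import torch

w = torch.tensor(2.0, requires_grad=True)
main_loss = (w - 1.0) ** 2   # stands in for the cross-entropy term
aux_loss = 0.5 * w**2        # stands in for the load-balance term

total = main_loss + (aux_loss - aux_loss.detach())
print(total.item())          # 1.0 == main_loss: reported magnitude unchanged
total.backward()
print(w.grad.item())         # 4.0 == 2*(w-1) + w: both gradients flow
```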
@@ -97,6 +97,18 @@ class Metrics:
    """Whether to log metrics to Weights & Biases"""


@dataclass
class ExtraLosses:
Contributor: This section is specifically for MoE load balancing loss for now. Do you foresee any other loss-related params being used in this section? If not, let's make the name more descriptive and specific.
Contributor: Followup here. Should we merge these configs to the …
    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
    """Type of load balance loss to use"""

    load_balance_loss_weight: float = 0
    """Weight of load balance loss"""

    load_balance_coeff: float | None = 1e-3
Contributor: Probably rename this to … (see torchtitan/torchtitan/components/optimizer.py, line 411 in 58fa181).
Contributor: I think both loss-free and loss-based load balancing are used simultaneously in DeepSeek-V3.
Contributor (Author): Yes, DeepSeek-V3 (and GLM 4.5, as far as I know) uses both.
Contributor: (Suggested change)
| """Coefficient of bias update for aux-loss-free load balancing""" | ||||||||
|
|
||||||||
|
|
||||||||
| @dataclass | ||||||||
| class Model: | ||||||||
| name: str = "llama3" | ||||||||
|
|
@@ -130,6 +142,9 @@ class Model: | |||||||
| converters have been applied. | ||||||||
| """ | ||||||||
|
|
||||||||
| extra_losses: ExtraLosses = field(default_factory=ExtraLosses) | ||||||||
| """Extra losses to use""" | ||||||||


@dataclass
class Optimizer:
@@ -95,6 +95,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
        )
        self.max_seq_len = seq_len

        losses_config = job_config.model.extra_losses
        self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
        self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
        self.moe_args.load_balance_coeff = losses_config.load_balance_coeff
Contributor: (Suggested change)

        if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
            logger.warning(
                "Failed to use grouped mm, which is only supported on SM90 or later",
@@ -30,7 +30,8 @@ class MoEArgs:
    top_k: int = 1
    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
    load_balance_coeff: float | None = 1e-3
Contributor: We can rename this config (here and in several other places) to differentiate loss-free load balancing based on adaptive bias updates from auxiliary-loss-based load balancing. (Suggested change)
    load_balance_loss_weight: float = 0
    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
    _debug_force_load_balance: bool = False
    # if True, we force each experts get same amount of token via round-robin
@@ -287,7 +288,7 @@ def forward(
            max=self.num_experts,
        )

-       return top_scores, selected_experts_indices, num_tokens_per_expert
+       return top_scores, scores, selected_experts_indices, num_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)

@@ -359,6 +360,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
        super().__init__()

        num_experts = moe_args.num_experts
        self.topk = moe_args.top_k
Contributor: nit. (Suggested change)
        self.num_experts = num_experts
Contributor: (Suggested change)
        self.experts = GroupedExperts(
            dim=dim,
            hidden_dim=hidden_dim,

@@ -386,6 +389,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
        # NOTE: tokens_per_expert is accumulated in the model forward pass.
        # expert_bias is updated outside the model in an optimizer step pre hook
        # to work with gradient accumulation.
        self.load_balance_loss_weight = moe_args.load_balance_loss_weight
        self.load_balance_loss_type = moe_args.load_balance_loss_type
        self.load_balance_coeff = moe_args.load_balance_coeff
        if self.load_balance_coeff is not None:
            assert self.load_balance_coeff > 0.0
@@ -418,6 +423,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        # num_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            scores,
            selected_experts_indices,
            num_tokens_per_expert,
        ) = self.router(x, self.expert_bias)

@@ -430,6 +436,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                self.tokens_per_expert.add_(num_tokens_per_expert)

        if self.training:
            if self.load_balance_loss_type == "sequence_wise":
                load_balance_loss = MoE.sequence_wise_aux_loss(
                    scores,
                    selected_experts_indices.long(),
                    bs,
                    slen,
                    self.topk,
Contributor: (Suggested change)
                    self.load_balance_loss_weight,
                )
            elif self.load_balance_loss_type == "batch_wise":
                load_balance_loss = MoE.batch_wise_aux_loss(
                    scores,
                    num_tokens_per_expert,
                    self.topk,
Contributor: (Suggested change)
                    self.load_balance_loss_weight,
                )
        else:
            load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)

        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
        # num_tokens_per_expert shape (num_experts,)
        # NOTE: the reason we need to compute num_tokens_per_expert again is:

@@ -479,7 +505,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            dim=0, index=token_indices_experts_sorted, src=routed_output
        )
        out = out.reshape(bs, slen, dim)
-       return out
Contributor: (Suggested change)
+       return out, load_balance_loss

    def init_weights(
        self,

@@ -499,3 +526,94 @@ def init_weights(
            self.expert_bias = torch.zeros(
                self.experts.num_experts, dtype=torch.float32
            )

    @staticmethod
    @torch.compile(fullgraph=True)
Contributor: n00b question: do we always want to compile this loss? Is it for speed purposes? Should we provide options for users to control whether they want to compile or not, like …
Contributor (Author): Yep, for speedup. I don't know whether, when we have compile + fullgraph, it will automatically be compiled or not (I would expect so).
    def sequence_wise_aux_loss(
        scores: torch.Tensor,  # Shape: (B*S, N) - Raw Sigmoid Affinities (s_{i,t})
        indices: torch.Tensor,  # Shape: (B*S, K) - Selected Expert Indices
        B: int,  # Batch size
        S: int,  # Sequence length (T in the paper)
        top_k: int,  # K_r
Contributor: The K_r here is the same as K elsewhere in this function, right? Maybe we can use a consistent notation top_k in all comments, and tell people this is K_r in the DeepSeek paper. Similarly, we can use N to denote the number of routed experts and tell people this is N_r in the DeepSeek paper.
        aux_loss_alpha: float,  # Alpha
Contributor: nit, just trying to be consistent with wording. (Suggested change)
    ) -> torch.Tensor:
        """
        Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20).

        Args:
            scores: The dense affinity scores (s_{i,t}) for routed experts.
                Should be the output of Sigmoid, shape (B*S, N).
            indices: The top-k selected expert indices. Shape (B*S, K).
        """
        if aux_loss_alpha <= 0:
            return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)

        # N_r: Total number of routed experts
        N = scores.size(-1)

        # 1. Reshape inputs to handle each sequence separately: (B, S, N)
        # This ensures we calculate P_i and f_i per sequence (Eq 20 & 18).
        scores_per_seq = scores.view(B, S, N)
        indices_per_seq = indices.view(B, S, top_k)
Contributor: This is not used afterwards. (Suggested change)

        # 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t}
        # DeepSeek-V3 uses Sigmoid, so scores don't sum to 1.
        # Eq 19 explicitly requires dividing by the sum of all affinities.
        # denominator shape: (B, S, 1)
        denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20
        probs_per_seq = scores_per_seq / denominator  # This is s'_{i,t}

        # 3. Eq 20: Calculate P_i (Average probability per expert for each sequence)
        # P_i = (1/T) * sum_{t=1}^T (s'_{i,t})
        # We average over the Sequence dimension (dim=1).
        # P_i shape: (B, N)
        P_i = probs_per_seq.mean(dim=1)

        # 4. Eq 18: Calculate f_i (Fraction of tokens selecting expert i per sequence)
        # f_i = (N / (K * T)) * count_i

        # Flatten the top-k dimension to count hits per sequence: (B, S*K)
        flat_indices_per_seq = indices_per_seq.view(B, -1)
Contributor: (Suggested change)
        selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
        src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
        selection_counts.scatter_add_(1, flat_indices_per_seq, src)
Comment on lines +577 to +579
Contributor: Seems to me we do not need to create a new … (Suggested change)

        # Calculate f_i for each sequence, T (tokens in sequence) is S
        f_i = selection_counts * (N / (top_k * S))

        # 5. Eq 17: Calculate Balance Loss
        loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha
Contributor: (Suggested change)

        return loss_per_seq.mean()

    @staticmethod
    @torch.compile(fullgraph=True)
    def batch_wise_aux_loss(
        scores: torch.Tensor,
        num_tokens_per_expert: torch.Tensor,
        top_k: int,
        aux_loss_alpha: float,
Contributor: (Suggested change)
    ) -> torch.Tensor:
        """
        Computes Batch-Wise Auxiliary Loss.
        Args:
            scores: Dense probabilities (BS, N).
            num_tokens_per_expert: Token counts (N).
            top_k: Number of experts selected per token.
            aux_loss_alpha: Scaling factor for the loss.
        """
        if aux_loss_alpha <= 0:
            return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)

        # Total number of routed experts (N)
        N = scores.size(1)
        # Total number of tokens (T = BS * S)
        T = scores.size(0)

        P_i = scores.mean(dim=0)

        f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T))

        loss = (f_i * P_i).sum() * aux_loss_alpha
Contributor: (Suggested change)

        return loss
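For reference, the formulas that the code comments above spell out (Equations 17-20 of the DeepSeek-V3 paper, per the docstring) can be written compactly as follows. This is a summary in the function's own notation (N routed experts, K = top_k, T = S tokens per sequence), not a quotation of the paper.

```latex
% Sequence-wise auxiliary balance loss, computed per sequence (Eqs. 17-20):
\mathcal{L}_{\mathrm{Bal}} = \alpha \sum_{i=1}^{N} f_i \, P_i, \qquad
f_i = \frac{N}{K\,T} \sum_{t=1}^{T} \mathbf{1}\!\left(\text{token } t \text{ selects expert } i\right), \qquad
s'_{i,t} = \frac{s_{i,t}}{\sum_{j=1}^{N} s_{j,t}}, \qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} s'_{i,t}.
```

The batch-wise variant in batch_wise_aux_loss evaluates the same expression, but computes f_i from num_tokens_per_expert and P_i as scores.mean(dim=0) over all B*S tokens in the local batch rather than per sequence.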
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import importlib
import os
import time

@@ -18,7 +19,7 @@
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderExhaustedError
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
-from torchtitan.components.loss import rescale_accumulated_loss
+from torchtitan.components.loss import moe_loss, rescale_accumulated_loss
from torchtitan.components.metrics import (
    build_metrics_processor,
    ensure_pp_loss_visible,

@@ -184,6 +185,11 @@ def __init__(self, job_config: JobConfig):
            job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
        )

        self.loss_fn = functools.partial(
Contributor (Author): We can add a condition here to wrap the loss or not for MoE. For now, all models in torchtitan only return a single output, so it's okay for now.
Contributor: If subsume this …
            moe_loss,
            loss_fn=self.loss_fn,
        )

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
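To make the wiring above concrete, here is a self-contained sketch (names, shapes, and the toy base loss are illustrative, not the trainer's actual code) of how the functools.partial wrapper behaves for a dense model output versus an MoE (output, load_balance_loss) tuple.

```python
# Illustrative sketch: mirrors the moe_loss wrapper from this PR with a toy
# cross-entropy base loss, to show both the tensor and tuple code paths.
import functools
import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Toy stand-in modeled on a standard (batch, seq, vocab) cross-entropy loss.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


def moe_loss(pred, labels, loss_fn):
    # Same logic as the moe_loss added to torchtitan/components/loss.py above.
    if isinstance(pred, tuple):
        pred, load_balance_loss = pred
        loss = loss_fn(pred, labels)
        loss = loss + (load_balance_loss - load_balance_loss.detach())
    else:
        loss = loss_fn(pred, labels)
    return loss


wrapped_loss_fn = functools.partial(moe_loss, loss_fn=cross_entropy_loss)

labels = torch.randint(0, 16, (2, 8))
logits = torch.randn(2, 8, 16, requires_grad=True)
aux = torch.tensor(0.01, requires_grad=True)

dense = wrapped_loss_fn(logits, labels)        # plain tensor output (non-MoE model)
moe = wrapped_loss_fn((logits, aux), labels)   # (output, aux_loss) tuple (MoE model)
```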
Contributor: I think we could have a consistent API with the other loss functions: take job_config as input, and plug in the loss like the other loss functions in TrainSpec (torchtitan/torchtitan/models/deepseek_v3/__init__.py, line 170 in 58fa181). That way we could avoid the change in train.py. WDYT?
Contributor: I agree. I think we can use a new build_loss_fn for models that possibly have MoE. Or we can update build_cross_entropy_loss by checking whether MoE is enabled from the config here (torchtitan/torchtitan/components/loss.py, line 29 in ad9f188).
Contributor: You mean something like build_multiple_loss? Or do we do build_ce_and_moe_loss and build_mse_and_moe_loss?
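One possible shape for the builder being discussed here, purely as a hypothetical sketch: the builder name and the decision to wrap unconditionally are assumptions, not code from this PR; only build_cross_entropy_loss and moe_loss (both in torchtitan.components.loss) come from the document.

```python
# Hypothetical sketch of the reviewers' suggestion: a build_* function that takes
# job_config like the other loss builders in TrainSpec and returns the wrapped
# loss, so train.py would not need to apply functools.partial itself.
import functools

from torchtitan.components.loss import build_cross_entropy_loss, moe_loss


def build_moe_cross_entropy_loss(job_config, **kwargs):
    """Hypothetical builder (name assumed): keeps the same signature as the
    existing loss builders so it could be plugged into a model's TrainSpec."""
    base_loss_fn = build_cross_entropy_loss(job_config, **kwargs)
    # moe_loss passes plain-tensor outputs straight through to the base loss,
    # so the wrapper is harmless for dense models and handles MoE models that
    # return an (output, load_balance_loss) tuple.
    return functools.partial(moe_loss, loss_fn=base_loss_fn)
```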