
[Fix] Fix nan error for large scale ep#12866

Merged
Fridge003 merged 3 commits into main from baizhou/fix-dp on Nov 11, 2025
Conversation

Collaborator

@Fridge003 Fridge003 commented Nov 8, 2025

Motivation

Part of #12293

This bug was introduced in #10874, which wrongly removes some of the redundant experts from logical_to_all_physical_map (e.g., the mapping for logical expert 0 should be [0, 256], but was wrongly set to [0, -1]).

This only happens on the first and the last node. On these two nodes, the values of w13_input_scale become NaN:

[2025-11-08 02:52:56 DP2 TP2 EP2] w13_input_scale values: tensor([nan, nan, nan, nan, nan, nan], device='cuda:2', dtype=torch.float32)
[2025-11-08 02:52:56 DP1 TP1 EP1] w13_input_scale values: tensor([nan, nan, nan, nan, nan, nan], device='cuda:1', dtype=torch.float32)
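To make the failure mode concrete, here is a minimal sketch of the difference between the old and the fixed assignment logic. The function names are invented for illustration; this is not the actual sglang code:

```python
# Hypothetical sketch of the bug described above (names invented,
# not the sglang implementation).

def buggy_assign(mapped, nearest):
    # Old logic from the breaking PR, roughly: keep only the nearest
    # expert and pad the remaining redundant slots with -1.
    if nearest != -1:
        return [nearest] + [-1] * (len(mapped) - 1)
    return list(mapped)

def fixed_assign(mapped, nearest):
    # Logic after this PR: only overwrite the first slot, and only if
    # the nearest expert is not already among the mapped experts, so
    # redundant physical experts are preserved.
    if nearest != -1 and nearest not in mapped:
        return [nearest] + list(mapped[1:])
    return list(mapped)

# Logical expert 0 is replicated on physical experts 0 and 256.
print(buggy_assign([0, 256], 0))  # [0, -1]  -> expert 256 is lost
print(fixed_assign([0, 256], 0))  # [0, 256] -> redundancy preserved
```

With the fix, the nearest expert is only written into the first slot when it is not already mapped, so the redundant copy at physical expert 256 survives.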

Later, it hits errors like:

  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1460, in process_weights_after_loading
    assert torch.all(w13_input_scale == w13_input_scale[0])

This fix resolves the error.

Reproduction:
GB200, latest main branch
https://gist.github.com/kaixih/32bdc4fec4feabe9305d1acb2e1f96db

Modifications

Fix this by keeping the redundant experts on these nodes.


@gemini-code-assist
Contributor

Summary of Changes

Hello @Fridge003, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request addresses and resolves a NaN error that can occur in large-scale expert parallelism (EP) setups. The core change refines the assignment mechanism for logical experts to physical experts, specifically by introducing a check to prevent the redundant mapping of an already assigned physical expert. This enhancement aims to improve the stability and correctness of computations in distributed environments by ensuring consistent expert allocation.

Highlights

  • Preventing Duplicate Expert Assignments: Modified the expert mapping logic within _compute_logical_to_all_physical_map to ensure that a nearest_expert is only assigned if it's not already present in the list of mapped physical experts. This prevents potential redundant or incorrect assignments that could lead to NaN errors in large-scale expert parallelism.

@gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a NaN error in large-scale expert parallelism by modifying how the nearest physical expert is handled. The previous implementation could reduce expert redundancy, which might have been the source of the issue. However, the new implementation introduces a logical flaw that renders the nearest expert handling logic ineffective. My review identifies this issue, corrects a typo, and proposes a more robust solution to correctly prioritize the nearest expert without losing redundancy.

Comment on lines +358 to +365
mapped_phsical_experts = logical_to_all_physical_map[layer_id][
    logical_expert_id
]
if (
    nearest_expert != -1
    and nearest_expert not in mapped_phsical_experts
):
    mapped_phsical_experts[0] = nearest_expert
Contributor


Severity: high

There are a couple of issues in this block:

  1. There's a typo in the variable name mapped_phsical_experts. It should be mapped_physical_experts.
  2. The condition nearest_expert not in mapped_phsical_experts will always be false when nearest_expert != -1, because the _find_nearest_expert function returns an expert from the candidate list (mapped_phsical_experts). This makes the if block dead code. While this might inadvertently fix the original issue by disabling this logic, it's not a clean solution.
  3. If the intention is to prioritize the nearest_expert by moving it to the front of the list, simply replacing the first element with mapped_phsical_experts[0] = nearest_expert could create duplicates if nearest_expert is already in the list at a different position.

A better approach is to move the nearest_expert to the front of the list. This preserves all assigned experts and correctly prioritizes the nearest one.

Suggested change:

mapped_physical_experts = logical_to_all_physical_map[layer_id][
    logical_expert_id
]
if nearest_expert != -1:
    if nearest_expert in mapped_physical_experts:
        mapped_physical_experts.remove(nearest_expert)
    mapped_physical_experts.insert(0, nearest_expert)
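The reviewer's second point can be checked in isolation: if the nearest expert is always drawn from the already-mapped candidates, the membership guard can never pass. A minimal sketch, where find_nearest_expert is a hypothetical stand-in for the real _find_nearest_expert:

```python
def find_nearest_expert(candidates, rank):
    # Hypothetical stand-in: like the real helper, it only ever
    # returns -1 or one of the candidate physical experts.
    valid = [c for c in candidates if c != -1]
    if not valid:
        return -1
    return min(valid, key=lambda c: abs(c - rank))

mapped = [0, 256]
nearest = find_nearest_expert(mapped, rank=2)

# The guard from the patch can never fire when nearest != -1,
# because nearest is by construction a member of `mapped`.
guard_fires = nearest != -1 and nearest not in mapped
print(nearest, guard_fires)  # 0 False
```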
Collaborator

Why replace the first expert with the nearest?

Collaborator Author

@Fridge003 Fridge003 Nov 8, 2025


No specific reason. I'm checking this with the author of the breaking PR.
It's at least better than the original logic.

Collaborator


Sure. Just curious: the original logic seemed to replace all mappings with the nearest expert that is not -1, but the new logic seems to replace only the first mapping.

Collaborator Author


The original logic expelled some needed experts, which caused this bug.

Collaborator

@Fridge003 what does "some needed experts" mean? Is the first expert enough?

Collaborator Author

For example, on EP rank 0, [0, 256] is initialized. But the prior logic will change it to [0, -1], then expert 256 is missing.

Collaborator

If [0, 256] is initialized, it's possible for it to be changed to [256, 256]; will that still cause the issue?

Collaborator Author

That change would not cause this issue, since 256 would still be in there.
Or do you mean we need to check whether the replaced id is unique?

@wenscarl
Collaborator

With #10874, on the 1st and last node, the map could be:

[[0], [1], [2], ..., [24, 280], [25, 281], ..., [31, 287], [33], [34], ..., [255]]

where physical experts 256 to 279 are missing.

> If [0, 256] is initialized, it's possible to be changed to [256, 256], will it still cause the issue?

Where could 256 be placed?
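A minimal sketch of why missing map entries surface as NaN scales (hypothetical loading logic, not the actual sglang weight loader): per-physical-expert scales are filled by walking the logical-to-physical map, so any physical expert that no longer appears in the map keeps its uninitialized value:

```python
import math

def load_scales(logical_to_physical, num_physical):
    # Scales start uninitialized; modeled here as NaN.
    scales = [math.nan] * num_physical
    for physical_ids in logical_to_physical:
        for pid in physical_ids:
            if pid != -1:
                scales[pid] = 1.0  # stands in for the checkpoint value
    return scales

# Correct map: logical experts 0 and 1 are replicated on 4 and 5.
good_map = [[0, 4], [1, 5], [2, -1], [3, -1]]
# Buggy map: the redundant copies were overwritten with -1.
buggy_map = [[0, -1], [1, -1], [2, -1], [3, -1]]

print(load_scales(good_map, 6))   # all six slots filled
print(load_scales(buggy_map, 6))  # slots 4 and 5 stay NaN, which later
                                  # trips the all-equal scale assertion
```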

@wenscarl
Collaborator

Another question: w13_input_scale is supposed to be independent of the physical map, since each expert reads in 288 values anyway.

@kaixih
Collaborator

kaixih commented Nov 11, 2025

> Another question: w13_input_scale is supposed to be independent of the physical map, since each expert reads in 288 values anyway.

But doesn't it need the map to locate the 288 -> 256 logical experts and then load the weights (including the input scale) from them?

@Fridge003
Collaborator Author

> Another question: w13_input_scale is supposed to be independent of the physical map, since each expert reads in 288 values anyway.

> But doesn't it need the map to locate the 288 -> 256 logical experts and then load the weights (including the input scale) from them?

Yeah. If some experts are missing, it may produce NaN values.


@Fridge003
Collaborator Author

Since CIs are all green, let's merge this first to unblock some ongoing tasks for GB200.
If there is a better solution, we can open it in a follow-up PR.
@kaixih @wenscarl @acelyc111

@Fridge003 Fridge003 merged commit 99e2580 into main Nov 11, 2025
120 of 127 checks passed
@Fridge003 Fridge003 deleted the baizhou/fix-dp branch November 11, 2025 22:44
wenscarl added a commit to wenscarl/sglang that referenced this pull request Nov 12, 2025