
Bugfix: fix symm not enabled due to incorrect registration of comm #19329

Open
wangfakang wants to merge 10 commits into sgl-project:main from wangfakang:bugfix_symm

Conversation

@wangfakang
Contributor

@wangfakang wangfakang commented Feb 25, 2026

CC @ShangmingCai @nvcastet @Fridge003 @BBuf @yizhang2077 @ch-wan PTAL, thx.

Motivation

Fix symmetric memory (symm) not being enabled due to incorrect registration of the communication group:

  1. get_local_dp_buffer uses the group obtained by get_tp_group by default to register symm.
  2. The attn_cp_all_gather_into_tensor operation uses the group obtained by get_attention_cp_group to perform allgather.
  3. Steps 1 and 2 cause symm to fail when enabled because the two groups are inconsistent.

def _gather_hidden_states_and_residual(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    forward_batch: ForwardBatch,
    layernorm: torch.nn.Module,
    context: CommunicateContext,
    *,
    residual_input_mode,
):
    if hidden_states.shape[0] != 0:
        hidden_states, residual = layernorm(hidden_states, residual)
    # for prefill: attn tp scattered -> full
    # for decode: attn tp full -> full
    if nsa_use_prefill_cp(forward_batch):
        assert context.attn_dp_size == 1
        hidden_states, local_hidden_states = (
            get_local_dp_buffer(),
            hidden_states,
        )
        attn_cp_all_gather_into_tensor(
            hidden_states,
            local_hidden_states,
        )

def get_local_dp_buffer(cls) -> torch.Tensor:
    with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
        ...

def attn_cp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
    return get_attention_cp_group().all_gather_into_tensor(output, input)

# Set the env var to pass this argument to the C functions.
os.environ["SGLANG_TMP_NCCL_COMM_VALUE"] = str(
    self.group_coordinator.pynccl_comm.comm.value
)

const char *str_val = getenv("SGLANG_TMP_NCCL_COMM_VALUE");
char *endptr;
void* int_val = (void *)strtoull(str_val, &endptr, 0);
ncclComm_t comm = (ncclComm_t)(int_val);
ncclWindow_t win;
NCCLCHECK(ncclCommWindowRegister(comm, ptr, size, &win, NCCL_WIN_COLL_SYMMETRIC));
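The pointer handoff above can be sketched end to end in Python. This is a minimal sketch: the env-var name matches the snippets above, but `pass_comm_via_env` is a hypothetical helper, not sglang code. The C side parses the string with `strtoull(..., 0)`, and `int(s, 0)` performs the same base-0 parse:

```python
import os

def pass_comm_via_env(comm_value: int, var: str = "SGLANG_TMP_NCCL_COMM_VALUE") -> int:
    """Round-trip an integer communicator handle through an env var.

    Hypothetical helper: the Python side stores the handle as a decimal
    string; the C side recovers it with strtoull(getenv(var), &endptr, 0),
    which int(s, 0) mirrors here.
    """
    os.environ[var] = str(comm_value)
    return int(os.environ[var], 0)

# The handle survives the string round trip unchanged.
assert pass_comm_via_env(0x7F3A00001234) == 0x7F3A00001234
```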

Modifications

Modify the functions get_local_dp_buffer and get_global_dp_buffer to accept a GroupCoordinator parameter, so that the communication group registered for symm is consistent with the communication group used by the collective operator.


Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Signed-off-by: wangfakang <fakangwang@gmail.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @wangfakang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug preventing symmetric memory from being properly enabled in distributed operations. The fix involves enhancing the get_local_dp_buffer and get_global_dp_buffer functions to explicitly accept a communication group, thereby resolving inconsistencies between memory registration and collective communication patterns. This change ensures that distributed attention and MoE operations utilize the correct communication groups, leading to proper functionality of symmetric memory.

Highlights

  • Communication Group Inconsistency: Identified and fixed a bug where symmetric memory (symm) was not correctly enabled due to get_local_dp_buffer and get_global_dp_buffer using a default communication group (get_tp_group()) that was inconsistent with the group used by collective operations like attn_cp_all_gather_into_tensor (get_attention_cp_group()).
  • API Modification for Group Coordination: Modified the get_local_dp_buffer and get_global_dp_buffer functions to accept an explicit GroupCoordinator parameter, allowing callers to specify the correct communication group for symmetric memory registration.
  • Consistent Group Usage: Updated various call sites across communicator.py, communicator_nsa_cp.py, and token_dispatcher/standard.py to pass the appropriate communication group (e.g., get_attention_tp_group(), get_attention_cp_group(), or get_tp_group()) to get_local_dp_buffer and get_global_dp_buffer, ensuring consistency between symm registration and collective operations.


Changelog
  • python/sglang/srt/layers/communicator.py
    • Imported get_attention_tp_group.
    • Updated calls to get_local_dp_buffer within _scattered_to_tp_attn_full, _gather_hidden_states_and_residual, _scatter_hidden_states, and _gather to pass get_attention_tp_group().
    • Updated a call to get_global_dp_buffer within _gather_hidden_states_and_residual to pass get_tp_group().
    • Added logic in _scatter_hidden_states to determine the appropriate group (get_tp_group() or get_attention_tp_group()) based on tensor model parallel world size and attention data parallel size.
  • python/sglang/srt/layers/communicator_nsa_cp.py
    • Imported get_attention_cp_group.
    • Updated the call to get_local_dp_buffer within _gather_hidden_states_and_residual to pass get_attention_cp_group().
  • python/sglang/srt/layers/dp_attention.py
    • Modified the class methods _DpGatheredBufferWrapper.get_global_dp_buffer and _DpGatheredBufferWrapper.get_local_dp_buffer to accept a group: GroupCoordinator parameter.
    • Modified the standalone functions get_global_dp_buffer and get_local_dp_buffer to accept and pass through the group parameter to their respective wrapper methods.
  • python/sglang/srt/layers/moe/token_dispatcher/standard.py
    • Updated the call to get_local_dp_buffer within the combine method to pass get_tp_group().

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a bug where symmetric memory (symm) was not being enabled due to incorrect communication group registration. The fix involves adding a group parameter to get_local_dp_buffer and get_global_dp_buffer to make the communication group explicit. This change is consistently applied across various call sites, ensuring that the appropriate communication group (get_attention_tp_group, get_tp_group, or get_attention_cp_group) is used for symmetric memory registration, aligning it with the group used by the corresponding collective communication operator. The implementation is clean and effectively resolves the issue.

Signed-off-by: wangfakang <fakangwang@gmail.com>
Comment on lines -129 to +130

- def get_local_dp_buffer(cls) -> torch.Tensor:
-     with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
+ def get_local_dp_buffer(cls, group: GroupCoordinator) -> torch.Tensor:
+     with use_symmetric_memory(group, disabled=not cls._dp_max_padding):
Collaborator


Why not make group: Optional[GroupCoordinator] = None, if group==None, then we still use get_tp_group() by default.

Contributor Author

@wangfakang wangfakang Feb 26, 2026


Why not make group: Optional[GroupCoordinator] = None, if group==None, then we still use get_tp_group() by default.

Using default values can easily lead to group inconsistency, so it's necessary to explicitly declare the correct group to ensure consistency with the communication operator's context.

@nvcastet
Collaborator

nvcastet commented Feb 26, 2026

@wangfakang Thanks for your PR!
I believe there is an issue here when multiple groups (communicators) are used with the "use_symmetric_memory" context manager in the same run.
The current behavior is:

  • If the pool does not have enough inactive memory, it will allocate and register memory with the specified group.
  • If the pool has enough inactive memory, it just returns a memory block from the pool (without considering the specified group).

So we would need to come up with a design to address this issue, i.e., making sure the memory returned by the pool is properly registered with the correct group.

In a previous design, we decoupled registration from allocation and could register or re-register memory segments at the exit of the context manager, but now we perform registration at allocation time since we did not want to pay the CPU cost of the PyTorch memory snapshot API.
I guess we could come back to this earlier design but track the allocations ourselves in C++ instead of going through the memory snapshot.
CC @merrymercy

@wangfakang
Contributor Author

wangfakang commented Feb 27, 2026

@nvcastet You're absolutely right. Indeed, there are two distinct issues here and this PR fixes the first issue (missing group passing). The second issue (memory snapshot inconsistency) remains and needs a separate solution, possibly via C++ tracking or reverting to decoupled registration.

@nvcastet
Collaborator

nvcastet commented Feb 27, 2026

Please, if possible, use your own words to answer review comments instead of AI. It is easier when we have more direct and accurate back and forth.
I don't think they are distinct issues. I think it would be better to fix the other issue first before merging this one, since this one would give you the impression buffers are registered when they are actually still associated with the tp group.
And before merging those fixes, we would need to re-run key configs for DSR1 and Qwen to check that we did not break anything.

@wangfakang
Contributor Author

wangfakang commented Mar 2, 2026

@nvcastet I apologize for the confusion. Would reverting to the previous design require another PR? Thank you.

@nvcastet
Collaborator

nvcastet commented Mar 2, 2026

@wangfakang:
@merrymercy can chime in on that topic, but the issue with rolling back was the snapshot() API cost; see the previous version at https://github.com/sgl-project/sglang/pull/12524/changes#diff-1857ea2e79f03309e0776136d7e45b432e0369c20ff8a57d418d68c764bb733f.
We would need a new design decoupling registration without the CPU overhead of snapshot().

@wangfakang
Contributor Author

wangfakang commented Mar 9, 2026

@nvcastet @merrymercy I refactored SymmPool to replace the global MemPool with a per-group MemPool dictionary in #20153. Now each communication group has its own MemPool, ensuring proper memory registration and preventing cross-group allocation issues in multi-comm scenarios.
Combining #19329 and #20153 completely fixes the two issues mentioned by @nvcastet. I can observe the effects through the NCCL tuning log when enabling symm locally.
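The per-group pool idea can be sketched minimally as follows (names and structure are hypothetical and the real #20153 implementation differs; a bytearray stands in for a window-registered buffer): keying the free-block pool by communication group means a reused block was necessarily allocated, and registered, for the group it will be used with.

```python
class SymmPoolSketch:
    """Toy per-group pool: one free list per communication group."""

    def __init__(self):
        self._pools = {}  # group name -> list of (size, block)

    def get_buffer(self, group: str, size: int):
        pool = self._pools.setdefault(group, [])
        for i, (blk_size, blk) in enumerate(pool):
            if blk_size >= size:
                # Reuse: this block was allocated (and, in the real code,
                # window-registered) for exactly this group.
                return pool.pop(i)[1]
        # No reusable block: allocate fresh; the real code would also call
        # ncclCommWindowRegister with this group's communicator here.
        return bytearray(size)

    def put_buffer(self, group: str, blk):
        self._pools.setdefault(group, []).append((len(blk), blk))

pool = SymmPoolSketch()
a = pool.get_buffer("attention_cp", 1024)
pool.put_buffer("attention_cp", a)
b = pool.get_buffer("tp", 1024)   # different group: never reuses a
assert b is not a
assert pool.get_buffer("attention_cp", 512) is a  # same group: reuse
```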

