
[CP] Register KV cache allgather buffer with symmetric memory #24040

Merged
ShangmingCai merged 3 commits into sgl-project:main from wangfakang:opt-cp
May 6, 2026

Conversation

@wangfakang
Contributor

wangfakang commented Apr 29, 2026

cc @ShangmingCai @Fridge003 PTAL, thx.

Motivation

[CP] Fix missing symmetric memory registration in cp_all_gather_reorganized_into_tensor_kv_cache (#22914 follow-up)

When PR #22914 refactored and consolidated the NSA utils.py into cp_utils.py, it missed wrapping the KV cache allgather buffer creation in cp_all_gather_reorganized_into_tensor_kv_cache with use_symmetric_memory. This change adds the missing wrapping so the buffer is properly registered with symmetric memory, improving communication efficiency when symmetric memory is available.

Original title: Register cp-atten-allgather buffers with symm memory
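For reference, a minimal sketch of what the fix amounts to. This is not the exact PR diff: the function name and parameters below are illustrative, while use_symmetric_memory, get_attention_cp_group, and is_allocation_symmetric are the helpers used in cp_utils.py (see the review excerpt further down).

    import torch

    def allocate_allgather_kv_output_sketch(
        local_kv: torch.Tensor,
        cp_size: int,
        use_symmetric_memory,        # context manager from sglang's symmetric-memory utils
        attention_cp_group,          # result of get_attention_cp_group()
        allocation_symmetric: bool,  # result of is_allocation_symmetric()
    ) -> torch.Tensor:
        # The key change: the allgather output buffer is allocated *inside* the
        # use_symmetric_memory context, so it is registered with the symmetric
        # memory pool whenever that path is enabled.
        with use_symmetric_memory(
            attention_cp_group, disabled=not allocation_symmetric
        ):
            # Create output tensor with proper shape for all dimensions.
            output = torch.empty(
                (local_kv.shape[0] * cp_size, *local_kv.shape[1:]),
                dtype=local_kv.dtype,
                device=local_kv.device,
            )
        # ... all_gather into `output` and reorganize the KV cache as before ...
        return output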

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Signed-off-by: wangfakang <fakangwang@gmail.com>
Contributor

gemini-code-assist (Bot) left a comment


Code Review

This pull request updates cp_utils.py to ensure that tensor allocations are performed within the use_symmetric_memory context manager. A review comment suggests moving a descriptive comment inside the context manager to improve logical grouping and code clarity.

I am having trouble creating individual review comments, so my feedback is included below.

python/sglang/srt/layers/utils/cp_utils.py, line 133 — severity: medium

The comment 'Create output tensor with proper shape for all dimensions' is placed outside the context manager, but it describes the logic inside the context manager. It should be moved inside to maintain logical grouping.

    with use_symmetric_memory(
        get_attention_cp_group(), disabled=not is_allocation_symmetric()
    ):
        # Create output tensor with proper shape for all dimensions

Collaborator

ShangmingCai left a comment


Looks reasonable.

CC: @Shunkangz could you take a look?

@Shunkangz
Contributor

LGTM. One more small question: does this mean we might need to allocate memory through ncclMemAlloc at runtime if the required size is larger than the pool size, since the number of tokens can vary at runtime? If we want to avoid that, we should run a maximum-token request during warmup. Is that a correct understanding?

@wangfakang
Contributor Author

> LGTM. One more small question: does this mean we might need to allocate memory through ncclMemAlloc at runtime if the required size is larger than the pool size, since the number of tokens can vary at runtime? If we want to avoid that, we should run a maximum-token request during warmup. Is that a correct understanding?

Yes, that's correct. However, when SGLang starts up it currently performs a check: if symm is enabled, it pre-allocates 4GB of memory by default for warmup. This avoids frequent subsequent allocations due to insufficient memory, which can lead to memory fragmentation. Additionally, once the PR for restructuring the symm pool is merged, the memory pool will be shared across the various communications.
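To illustrate the behavior described above, here is a small hypothetical sketch (not sglang's actual allocator API): a fixed pool is pre-registered at startup, and only a request that exceeds the remaining pool space falls back to a runtime ncclMemAlloc-style allocation, which is what warming up with a maximum-token request is meant to avoid.

    SYMM_POOL_BYTES = 4 * 1024**3  # default 4GB warmup pool mentioned above

    class SymmetricPoolSketch:
        """Hypothetical model of the pool behavior; names are not sglang's API."""

        def __init__(self, pool_bytes: int = SYMM_POOL_BYTES):
            self.pool_bytes = pool_bytes
            self.used = 0

        def allocate(self, nbytes: int) -> str:
            if self.used + nbytes <= self.pool_bytes:
                self.used += nbytes
                return "pool"  # served from the pre-registered symmetric pool
            # Oversized request: would need a runtime ncclMemAlloc plus registration,
            # which can fragment memory if it happens repeatedly.
            return "runtime_alloc"

    # A max-token warmup request sizes the pool usage up front, so later
    # variable-length batches stay on the "pool" path.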

@Shunkangz
Contributor

> LGTM. One more small question: does this mean we might need to allocate memory through ncclMemAlloc at runtime if the required size is larger than the pool size, since the number of tokens can vary at runtime? If we want to avoid that, we should run a maximum-token request during warmup. Is that a correct understanding?

> Yes, that's correct. However, when SGLang starts up it currently performs a check: if symm is enabled, it pre-allocates 4GB of memory by default for warmup. This avoids frequent subsequent allocations due to insufficient memory, which can lead to memory fragmentation. Additionally, once the PR for restructuring the symm pool is merged, the memory pool will be shared across the various communications.

Thank you for the detailed explanation.

@ShangmingCai
Collaborator

/tag-and-rerun-ci

@wangfakang
Contributor Author

/rerun-failed-ci

Triggering the waiting test task.

@wangfakang
Contributor Author

wangfakang commented May 6, 2026

Friendly ping @ShangmingCai @Shunkangz
I checked the logs and found that the failing test case is stage-b-test-1-gpu-large (8). The error message is ModuleNotFoundError: No module named 'sglang.srt.layers.moe.fused_moe_triton.fused_moe', which is unrelated to the changes in this PR.

[screenshot: CI log showing the failing test]

@ShangmingCai
Collaborator

/rerun-failed-ci

@ShangmingCai
Collaborator

/rerun-stage stage-c-test-deepep-8-gpu-h200

@ShangmingCai
Collaborator

/rerun-stage stage-c-test-4-gpu-h100

@github-actions
Contributor

github-actions Bot commented May 6, 2026

✅ Triggered stage-c-test-deepep-8-gpu-h200 to run independently (skipping dependencies). View workflow run

@github-actions
Contributor

github-actions Bot commented May 6, 2026

✅ Triggered stage-c-test-4-gpu-h100 to run independently (skipping dependencies). View workflow run

@ShangmingCai
Collaborator

Related CI has passed.

[screenshots: related CI checks passing]

ShangmingCai merged commit bfc1aea into sgl-project:main on May 6, 2026
296 of 345 checks passed
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 7, 2026
* main: (894 commits)
  [Bug Fix] Fix RunAI streamer: corrupted weights, missing quant init, and broken URIs for multimodal models (sgl-project#22715)
  [Kernel] Deprecate DeepGemm in sgl kernel and apply custom wheel sgl-deep-gemm (sgl-project#24268)
  propagate pytest exit code from test __main__ entries (sgl-project#24487)
  [R3] Avoid implicit CUDA sync in routed experts DP slicing (sgl-project#24550)
  Add ChatCompletionRequest-style support to /v1/tokenize (sgl-project#23981)
  Support Triton MLA FP8 KV cache (sgl-project#20479)
  [diffusion] chore: align LTX-2 with official (sgl-project#24313)
  Expand support matrix for pypi wheel release (sgl-project#24565)
  [codex] Optimize Z-Image packed QKV (sgl-project#24117)
  [Misc] Fix breaking weight checker test (sgl-project#24553)
  [LoRA] Fix qkv_proj LoRA buffer sizing when tp_size > num_key_value_heads (sgl-project#24420)
  ci: bump test_mimo_models.py est_time 330 → 610 (sgl-project#24551)
  [CI] Temporarily disable marco/mcdse-2b-v1 in test_embedding_models (sgl-project#24279)
  Improve metrics, observability, and PD deploy tooling (sgl-project#24521)
  Fix diffusion fallback guards and validation (sgl-project#23335)
  [PD] Prevent update_status to Failed from cleared entries (sgl-project#24539)
  [CP] Register KV cache allgather buffer with symmetric memory (sgl-project#24040)
  Support getting checksums in weight checker (sgl-project#24537)
  Refactor buffer patterns in weight checker (sgl-project#24538)
  Add unit and end-to-end tests for weight checker (sgl-project#24536)
  ...

# Conflicts:
#	python/sglang/srt/managers/scheduler.py
#	python/sglang/srt/model_executor/model_runner.py
LLThomas pushed a commit to LLThomas/sglang that referenced this pull request May 8, 2026
