
Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head #24775

Merged
yhyang201 merged 7 commits into main from opt/mhc-pre-fold-reduction
May 10, 2026

Conversation

@yhyang201
Collaborator

@yhyang201 yhyang201 commented May 9, 2026

Summary

  • Fold split-k stage-1 reduction into big_fuse (eliminates one kernel launch for num_tokens <= 2048)
  • Use DeepGemm tf32_hc_prenorm_gemm for mhc_pre GEMM when SGLANG_OPT_DEEPGEMM_HC_PRENORM is enabled
  • Fuse RMSNorm into mhc_pre_big_fuse kernel (eliminates separate norm kernel launch + HBM round-trip)
  • Add fused Triton kernel for hc_head (fuses RMSNorm + Linear + Sigmoid-gate + weighted-sum into one kernel)
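For orientation, the unfused hc_head path that the new Triton kernel collapses into one launch can be sketched in NumPy. This is a minimal sketch: the function name, shapes, and gating layout are assumptions for illustration, not taken from the actual kernel.

```python
import numpy as np

def hc_head_reference(x, norm_w, lin_w, branches, eps=1e-6):
    """Unfused reference: RMSNorm -> Linear -> sigmoid gate -> weighted sum.

    x:        (tokens, hidden)            input hidden states
    norm_w:   (hidden,)                   RMSNorm weight
    lin_w:    (hidden, n_branch)          gate projection weight
    branches: (tokens, n_branch, hidden)  per-branch outputs to combine
    """
    # 1) RMSNorm (a separate kernel launch + HBM round-trip when unfused)
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    normed = x / rms * norm_w
    # 2) Linear projection to one logit per branch
    logits = normed @ lin_w                      # (tokens, n_branch)
    # 3) Sigmoid gate
    gates = 1.0 / (1.0 + np.exp(-logits))
    # 4) Weighted sum over branches
    return np.einsum("tb,tbh->th", gates, branches)
```

Each of the four numbered steps would otherwise read and write the full activation through HBM; fusing them keeps the intermediates in registers/shared memory.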

Ported from:

Microbench

Standalone microbenchmark of mhc_pre (with the real sglang RMSNorm kernel as baseline) and hc_head, using DSV4 parameters: hidden=7168, hc_mult=4. CUDA event timing, 100 iterations, trimmed top/bottom 10%.
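The trimming protocol above (100 samples, drop the top and bottom 10%) can be reproduced with a small helper; the function name is illustrative, not from the benchmark script:

```python
def trimmed_mean_ms(times_ms, trim_frac=0.10):
    """Mean of timings after discarding the fastest and slowest trim_frac."""
    times = sorted(times_ms)
    k = int(len(times) * trim_frac)  # samples to drop at each end
    kept = times[k:len(times) - k] if k > 0 else times
    return sum(kept) / len(kept)

# With 100 CUDA-event timings this drops the 10 fastest and 10 slowest
# samples and averages the middle 80, damping clock-boost and warmup noise.
```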

norm + mhc_pre (called 2x per decoder layer, both prefill and decode):

| tokens | GB300 main | GB300 PR | speedup | B300 main | B300 PR | speedup |
|-------:|-----------:|---------:|--------:|----------:|--------:|--------:|
| 1 | 0.195ms | 0.157ms | 1.24x | 0.046ms | 0.034ms | 1.35x |
| 32 | 0.191ms | 0.148ms | 1.29x | 0.045ms | 0.034ms | 1.32x |
| 128 | 0.191ms | 0.132ms | 1.45x | 0.045ms | 0.033ms | 1.36x |
| 512 | 0.173ms | 0.127ms | 1.36x | 0.049ms | 0.035ms | 1.40x |
| 1024 | 0.166ms | 0.126ms | 1.32x | 0.074ms | 0.046ms | 1.61x |
| 2048 | 0.176ms | 0.129ms | 1.36x | 0.134ms | 0.069ms | 1.94x |
| 4096 | 0.227ms | 0.157ms | 1.45x | 0.189ms | 0.113ms | 1.67x |
| 8192 | 0.387ms | 0.244ms | 1.59x | 0.352ms | 0.200ms | 1.76x |

hc_head (called 1x per forward on last PP rank):

| tokens | GB300 main | GB300 PR | speedup | B300 main | B300 PR | speedup |
|-------:|-----------:|---------:|--------:|----------:|--------:|--------:|
| 1 | 0.279ms | 0.081ms | 3.4x | 0.084ms | 0.065ms | 1.3x |
| 32 | 0.285ms | 0.080ms | 3.6x | 0.158ms | 0.065ms | 2.4x |
| 128 | 0.279ms | 0.081ms | 3.4x | 0.176ms | 0.066ms | 2.7x |
| 512 | 0.325ms | 0.113ms | 2.9x | 0.285ms | 0.099ms | 2.9x |
| 1024 | 0.486ms | 0.154ms | 3.2x | 0.468ms | 0.143ms | 3.3x |
| 2048 | 0.769ms | 0.303ms | 2.5x | 0.784ms | 0.294ms | 2.7x |
| 4096 | 1.479ms | 0.522ms | 2.8x | 1.516ms | 0.527ms | 2.9x |
| 8192 | 2.569ms | 1.002ms | 2.6x | 2.661ms | 1.013ms | 2.6x |

Limitations

  • This is a standalone microbench calling mhc_pre and hc_head in isolation, not an end-to-end serving benchmark. Real-world gains depend on how much time these kernels contribute to total per-token latency.
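The caveat above is Amdahl's law: the end-to-end gain is bounded by the fraction of per-token latency these kernels account for. A quick illustration (the 20% fraction and 1.4x kernel speedup are made-up numbers, not measurements from this PR):

```python
def end_to_end_speedup(kernel_frac, kernel_speedup):
    """Amdahl's law: overall speedup when only kernel_frac of total
    time is accelerated by kernel_speedup."""
    return 1.0 / ((1.0 - kernel_frac) + kernel_frac / kernel_speedup)

# If mhc_pre + hc_head were 20% of per-token latency and got 1.4x faster,
# the end-to-end gain would be about 1.06x, well below the microbench numbers.
print(round(end_to_end_speedup(0.20, 1.4), 3))
```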

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@yhyang201
Collaborator Author

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label May 9, 2026
yhyang201 and others added 4 commits May 9, 2026 08:37
Co-Authored-By: Cheng Wan <chwan@rice.edu>
Co-authored-by: Chunan Zeng <zcnrex@gmail.com>
Co-authored-by: Cheng Wan <chwan@rice.edu>
Co-authored-by: Cheng Wan <chwan@rice.edu>
@yhyang201 yhyang201 force-pushed the opt/mhc-pre-fold-reduction branch from 78e5ce3 to abeb7f8 Compare May 9, 2026 06:51
@yhyang201 yhyang201 changed the title Optimize mhc_pre by folding stage-1 reduction into big_fuse Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head May 9, 2026
@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-4-gpu-b200

@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-8-gpu-h200

@github-actions
Contributor

github-actions Bot commented May 9, 2026

❌ Stage stage-c-test-dsv4-4-gpu-b200 doesn't support isolated runs yet.

NVIDIA stages:

  • stage-a-test-1-gpu-small
  • stage-a-test-cpu
  • stage-b-test-1-gpu-small
  • stage-b-test-1-gpu-large
  • stage-b-test-2-gpu-large
  • stage-b-test-4-gpu-b200
  • stage-c-test-4-gpu-h100
  • stage-c-test-8-gpu-h200
  • stage-c-test-8-gpu-h20
  • stage-c-test-4-gpu-b200
  • stage-c-test-4-gpu-gb200
  • stage-c-test-deepep-4-gpu-h100
  • stage-c-test-deepep-8-gpu-h200
  • multimodal-gen-test-1-gpu
  • multimodal-gen-test-2-gpu
  • multimodal-gen-component-accuracy
  • multimodal-gen-component-accuracy-1-gpu
  • multimodal-gen-component-accuracy-2-gpu
  • multimodal-gen-test-1-b200

AMD stages:

  • sgl-kernel-unit-test-amd
  • sgl-kernel-unit-test-2-gpu-amd
  • stage-a-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd-nondeterministic
  • stage-b-test-1-gpu-small-amd-mi35x
  • stage-b-test-1-gpu-large-amd
  • stage-b-test-2-gpu-large-amd
  • multimodal-gen-test-1-gpu-amd
  • multimodal-gen-test-2-gpu-amd
  • stage-c-test-large-8-gpu-amd
  • stage-c-test-large-8-gpu-amd-mi35x

Other stages will be added soon. For now, use /rerun-failed-ci for those stages.

@github-actions
Contributor

github-actions Bot commented May 9, 2026

❌ Stage stage-c-test-dsv4-8-gpu-h200 doesn't support isolated runs yet.

NVIDIA stages:

  • stage-a-test-1-gpu-small
  • stage-a-test-cpu
  • stage-b-test-1-gpu-small
  • stage-b-test-1-gpu-large
  • stage-b-test-2-gpu-large
  • stage-b-test-4-gpu-b200
  • stage-c-test-4-gpu-h100
  • stage-c-test-8-gpu-h200
  • stage-c-test-8-gpu-h20
  • stage-c-test-4-gpu-b200
  • stage-c-test-4-gpu-gb200
  • stage-c-test-deepep-4-gpu-h100
  • stage-c-test-deepep-8-gpu-h200
  • multimodal-gen-test-1-gpu
  • multimodal-gen-test-2-gpu
  • multimodal-gen-component-accuracy
  • multimodal-gen-component-accuracy-1-gpu
  • multimodal-gen-component-accuracy-2-gpu
  • multimodal-gen-test-1-b200

AMD stages:

  • sgl-kernel-unit-test-amd
  • sgl-kernel-unit-test-2-gpu-amd
  • stage-a-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd-nondeterministic
  • stage-b-test-1-gpu-small-amd-mi35x
  • stage-b-test-1-gpu-large-amd
  • stage-b-test-2-gpu-large-amd
  • multimodal-gen-test-1-gpu-amd
  • multimodal-gen-test-2-gpu-amd
  • stage-c-test-large-8-gpu-amd
  • stage-c-test-large-8-gpu-amd-mi35x

Other stages will be added soon. For now, use /rerun-failed-ci for those stages.

@yhyang201
Collaborator Author

/rerun-test registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200.py

@github-actions
Contributor

github-actions Bot commented May 9, 2026

registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200.py: Unknown CUDA suite stage-c-test-dsv4-4-gpu-b200 in test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200.py.

Known suites: nightly-1-gpu, nightly-4-gpu, nightly-4-gpu-b200, nightly-8-gpu-b200, nightly-8-gpu-common, nightly-8-gpu-h20, nightly-8-gpu-h200, nightly-eval-text-2-gpu, nightly-eval-vlm-2-gpu, nightly-kernel-1-gpu, nightly-kernel-8-gpu-h200, nightly-perf-text-2-gpu, nightly-perf-vlm-2-gpu, stage-a-test-1-gpu-small, stage-a-test-cpu, stage-b-test-1-gpu-large, stage-b-test-1-gpu-small, stage-b-test-2-gpu-large, stage-b-test-4-gpu-b200, stage-c-test-4-gpu-b200, stage-c-test-4-gpu-h100, stage-c-test-8-gpu-h20, stage-c-test-8-gpu-h200, stage-c-test-deepep-4-gpu-h100, stage-c-test-deepep-8-gpu-h200, weekly-8-gpu-h200

@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-4-gpu-b200

@github-actions
Contributor

github-actions Bot commented May 9, 2026

✅ Triggered stage-c-test-dsv4-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-8-gpu-h20

@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-8-gpu-h200

@github-actions
Contributor

github-actions Bot commented May 9, 2026

❌ Stage stage-c-test-dsv4-8-gpu-h20 doesn't support isolated runs yet.

NVIDIA stages:

  • stage-a-test-1-gpu-small
  • stage-a-test-cpu
  • stage-b-test-1-gpu-small
  • stage-b-test-1-gpu-large
  • stage-b-test-2-gpu-large
  • stage-b-test-4-gpu-b200
  • stage-c-test-4-gpu-h100
  • stage-c-test-8-gpu-h200
  • stage-c-test-8-gpu-h20
  • stage-c-test-4-gpu-b200
  • stage-c-test-4-gpu-gb200
  • stage-c-test-dsv4-4-gpu-b200
  • stage-c-test-dsv4-8-gpu-h200
  • stage-c-test-deepep-4-gpu-h100
  • stage-c-test-deepep-8-gpu-h200
  • multimodal-gen-test-1-gpu
  • multimodal-gen-test-2-gpu
  • multimodal-gen-component-accuracy
  • multimodal-gen-component-accuracy-1-gpu
  • multimodal-gen-component-accuracy-2-gpu
  • multimodal-gen-test-1-b200

AMD stages:

  • sgl-kernel-unit-test-amd
  • sgl-kernel-unit-test-2-gpu-amd
  • stage-a-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd-nondeterministic
  • stage-b-test-1-gpu-small-amd-mi35x
  • stage-b-test-1-gpu-large-amd
  • stage-b-test-2-gpu-large-amd
  • multimodal-gen-test-1-gpu-amd
  • multimodal-gen-test-2-gpu-amd
  • stage-c-test-large-8-gpu-amd
  • stage-c-test-large-8-gpu-amd-mi35x

Other stages will be added soon. For now, use /rerun-failed-ci for those stages.

@github-actions
Contributor

github-actions Bot commented May 9, 2026

✅ Triggered stage-c-test-dsv4-8-gpu-h200 to run independently (skipping dependencies). View workflow run

T.alloc_fragment does not guarantee zero initialization.
The sumsq_per_pos accumulator must be explicitly cleared
before the pipelined loop to avoid garbage values corrupting
the RMSNorm computation, which caused all-zero model output.

Co-authored-by: Cheng Wan <chwan@rice.edu>
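The fix described in the commit message amounts to an explicit clear before the pipelined accumulation loop. A NumPy analogue of the bug and its fix, using `np.empty` as a stand-in for an uninitialized `T.alloc_fragment` buffer (the loop structure is illustrative, not the kernel's actual code):

```python
import numpy as np

def rms_sumsq(x_tiles):
    """Accumulate per-position sum-of-squares across pipelined tiles."""
    n_pos = x_tiles[0].shape[0]
    # Like T.alloc_fragment, np.empty does NOT guarantee zeros -- the
    # buffer may contain garbage from a previous allocation.
    sumsq_per_pos = np.empty(n_pos)
    sumsq_per_pos[:] = 0.0  # the fix: explicitly clear before the loop
    for tile in x_tiles:
        sumsq_per_pos += np.sum(tile * tile, axis=-1)
    return sumsq_per_pos
```

Without the explicit clear, whatever happened to be in the fragment is folded into the RMSNorm denominator, which is consistent with the all-zero model output the commit describes.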
@yhyang201 yhyang201 removed the run-ci label May 9, 2026
@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-8-gpu-h200

@yhyang201
Collaborator Author

/rerun-stage stage-c-test-dsv4-4-gpu-b200

@github-actions
Contributor

github-actions Bot commented May 9, 2026

✅ Triggered stage-c-test-dsv4-8-gpu-h200 to run independently (skipping dependencies). View workflow run

@github-actions
Contributor

github-actions Bot commented May 9, 2026

✅ Triggered stage-c-test-dsv4-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@yhyang201
Collaborator Author

All tests related to dpskv4 have passed.

@yhyang201 yhyang201 merged commit 2f06867 into main May 10, 2026
167 of 179 checks passed
@yhyang201 yhyang201 deleted the opt/mhc-pre-fold-reduction branch May 10, 2026 11:03
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 11, 2026
* main: (87 commits)
  [Fix] Disable FlashInfer allreduce fusion under deterministic inference (sgl-project#24629)
  fix: STANDALONE spec-decode hidden-size mismatch crash (sgl-project#24217)
  Followup fix for Custom AR V2 in non NVL scenarios (sgl-project#24742)
  Fix reduce_scatterv producer contract for SUM_LEN (sgl-project#24785)
  [NPU]Documentation update for communications quantization feature (sgl-project#24668)
  [Session R3] Add routed_experts_start_len for absolute routing slice control (sgl-project#24851)
  [Model] Add MiniCPM-V 4.6 support (sgl-project#24855)
  Support Intern-S2-Preview (sgl-project#24875)
  [PD] Unify dsv4 dispatch with swa (sgl-project#24888)
  Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head (sgl-project#24775)
  Fix PD bootstrap failure handling (sgl-project#24772)
  [Spec] Cleanup idle stub and shape-check patterns (sgl-project#24881)
  [Bug] Add dsv4 state_type branch to mooncake disaggregation (sgl-project#24878)
  [Spec V1] Split draft-extend phase from `EagleDraftInput` into new `EagleDraftExtendInput` (sgl-project#24859)
  [Gemma4] Optimize Gemm4 with fused Q/K/V RMSNorm + per-expert FP8 ckpt loader (sgl-project#24696)
  [spec decoding] support kimi-k2.5-eagle3-mla (sgl-project#24826)
  [SPEC V2] fix: skip stale state updates in spec-v2 overlap (sgl-project#23456)
  [RL] Call torch.cuda.empty_cache() for `in-place` pause mode to avoid OOM (sgl-project#24854)
  [diffusion] CI: add cache-dit CI tests (sgl-project#19213)
  [Utils] Make request dump robust to unpicklable server_args and large meta_info (sgl-project#24767)
  ...

# Conflicts:
#	python/sglang/srt/utils/common.py