
[Quantization] add humming kernel support for deepseek v4 #24289

Open
jinzhen-lin wants to merge 11 commits into sgl-project:deepseek_v4 from jinzhen-lin:humming_deepseek_v4

Conversation

@jinzhen-lin

This PR adds Humming kernels to SGLang. It is based on #23754, adding and improving DeepSeek V4 support on top of it.

Humming kernels: https://github.com/inclusionAI/humming


Humming is a universal, high-performance quantization kernel (similar to the Marlin kernel) that offers several advantages over Marlin:

  • Extensive quantization support: all combinations of W{1,2,3,4,5,6,7,8}A{16,8,4} for quantized inference (see the sketch after this list).
  • Superior performance: Humming outperforms Marlin, especially at large batch sizes and on Hopper GPUs.
  • Enhanced JIT support: compared to the current Marlin JIT implementation in SGLang, Humming compiles faster.
  • For DeepSeek V4, Humming provides a high-performance W4A8 implementation on Hopper.
  • Supports DeepEP.
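
For concreteness, here is a quick expansion of the bit-width claim above; this is plain arithmetic over the PR text, not a call into Humming's API:

```python
import itertools

# Enumerate the W{1..8} x A{16,8,4} combinations claimed above.
# This only expands the claim in the PR description; it does not query Humming.
combos = [f"W{w}A{a}" for w, a in itertools.product(range(1, 9), (16, 8, 4))]
print(len(combos), "combinations:", ", ".join(combos))  # prints 24 combinations
```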

Benchmark

Service start command

```bash
# marlin w4a16
sglang serve \
  --trust-remote-code \
  --model-path /home/admin/DeepSeek-V4-Flash \
  --tp 4 \
  --moe-runner-backend marlin \
  --tool-call-parser deepseekv4 \
  --reasoning-parser deepseek-v4 \
  --host 0.0.0.0 \
  --port 12321

# humming w4a16
sglang serve \
  --trust-remote-code \
  --model-path /home/admin/DeepSeek-V4-Flash \
  --tp 4 \
  --moe-runner-backend humming \
  --tool-call-parser deepseekv4 \
  --reasoning-parser deepseek-v4 \
  --host 0.0.0.0 \
  --port 12321

# humming w4a8
SGLANG_HUMMING_INPUT_QUANT_CONFIG='{"dtype": "float8e4m3"}' sglang serve \
  --trust-remote-code \
  --model-path /home/admin/DeepSeek-V4-Flash \
  --tp 4 \
  --moe-runner-backend humming \
  --tool-call-parser deepseekv4 \
  --reasoning-parser deepseek-v4 \
  --host 0.0.0.0 \
  --port 12321
```
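
The W4A8 path is selected entirely through the `SGLANG_HUMMING_INPUT_QUANT_CONFIG` environment variable, which carries a JSON object. A minimal sketch of reading such a JSON-valued variable, assuming a bfloat16 fallback; this is illustrative, not SGLang's actual parsing code:

```python
import json
import os

# Hedged sketch: read a JSON-valued env var such as SGLANG_HUMMING_INPUT_QUANT_CONFIG.
# The fallback dtype is an assumption for illustration only.
raw = os.environ.get("SGLANG_HUMMING_INPUT_QUANT_CONFIG")
cfg = json.loads(raw) if raw else {}
dtype = cfg.get("dtype", "bfloat16")  # '{"dtype": "float8e4m3"}' enables W4A8
print(f"Humming input activation quantization dtype: {dtype}")
```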

Benchmark commands

```bash
# for prefill
python3 -m sglang.bench_serving \
    --backend sglang \
    --tokenizer /home/admin/DeepSeek-V4-Flash/ \
    --port 12321 \
    --dataset-name random-ids \
    --random-input-len 65536 \
    --random-output-len 1 \
    --num-prompts 128 \
    --max-concurrency 32 \
    --random-range-ratio 1

# for decoding
python3 -m sglang.bench_serving \
    --backend sglang \
    --tokenizer /home/admin/DeepSeek-V4-Flash/ \
    --port 12321 \
    --dataset-name random-ids \
    --random-input-len 1 \
    --random-output-len 4096 \
    --num-prompts 256 \
    --max-concurrency 64 \
    --random-range-ratio 1
```
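
Before benchmarking, a quick sanity check that the server is up; this assumes SGLang's OpenAI-compatible endpoint on the port used above, and the `model` value here is a placeholder:

```python
import requests

# Hedged smoke test against the server launched above. The endpoint path and
# "model" field follow the usual OpenAI-compatible convention; adjust as needed.
resp = requests.post(
    "http://127.0.0.1:12321/v1/chat/completions",
    json={
        "model": "/home/admin/DeepSeek-V4-Flash",
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 8,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```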

Benchmark results (TPS)

|               | Prefill | Decoding |
|---------------|---------|----------|
| Marlin W4A16  | 7272.49 | 2286.72  |
| Humming W4A16 | 9060.34 | 2278.49  |
| Humming W4A8  | 9917.43 | 2335.14  |
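
Relative to the Marlin W4A16 baseline this is roughly a 1.25x (W4A16) and 1.36x (W4A8) prefill speedup, with decoding essentially flat; the ratios follow directly from the table:

```python
# Relative throughput vs. the Marlin W4A16 baseline, taken from the table above.
results = {
    "Marlin W4A16": (7272.49, 2286.72),   # (prefill TPS, decoding TPS)
    "Humming W4A16": (9060.34, 2278.49),
    "Humming W4A8": (9917.43, 2335.14),
}
base_prefill, base_decode = results["Marlin W4A16"]
for name, (prefill, decode) in results.items():
    print(f"{name}: prefill {prefill / base_prefill:.2f}x, "
          f"decoding {decode / base_decode:.2f}x")
```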

In SGLang, splitkv_mla and paged_mqa are used for the prefill phase of DeepSeek V4, and the attention part currently takes longer than expected. Once that is fixed, Humming should show an even larger end-to-end improvement.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the "Humming" quantization backend and MoE runner, adding optimized Triton and CUDA kernels for specialized quantization formats like MXFP4. The feedback highlights critical issues such as a potential memory leak in runner registration, possible out-of-bounds memory access in the Triton kernel, and problematic in-place configuration modifications. Additionally, the review suggests fixing a typo in attribute mapping, removing redundant rounding operations, and handling variable data type sizes more accurately during memory allocation.

Comment thread python/sglang/srt/layers/quantization/mxfp4_deepseek.py Outdated
Comment thread python/sglang/srt/layers/moe/fused_moe_triton/moe_fused_mul_sum.py Outdated
Comment thread python/sglang/srt/layers/moe/moe_runner/humming.py
Comment thread python/sglang/srt/layers/quantization/humming.py Outdated
Comment thread python/sglang/srt/layers/moe/moe_runner/humming.py
jinzhen-lin and others added 3 commits May 3, 2026 12:04
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…m.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@jinzhen-lin jinzhen-lin changed the title Humming deepseek v4 [Quantization] add humming kernel support for deepseek v4 May 3, 2026
@jinzhen-lin jinzhen-lin mentioned this pull request May 3, 2026
33 tasks
@fzyzcjy
Collaborator

fzyzcjy commented May 3, 2026

cc @Fridge003 for hopper w4a16 kernels

@superuct

superuct commented May 6, 2026

Hello, can DeepSeek-V4 Pro use the Humming kernel?

@jinzhen-lin
Author

It should be supported, but I haven't actually run it myself yet. Feel free to try it and share feedback.

@huangzhilin-hzl
Contributor

@jinzhen-lin Hi, fix for the sgl-kernel build error: jinzhen-lin#1

@huangzhilin-hzl
Contributor

Fix applying the 2604B SwiGLU clamp/checker path jinzhen-lin#2

@huangzhilin-hzl
Contributor

fix DeepEP empty-token path error jinzhen-lin#3

@superuct

superuct commented May 9, 2026

> It should be supported, but I haven't actually run it myself yet. Feel free to try it and share feedback.

Yes, I have tried the Pro model; Humming also worked with your latest code.

