[Quantization] add humming mxfp4 moe backend #41083

Merged
vllm-bot merged 7 commits into vllm-project:main from jinzhen-lin:humming_mxfp4_moe_backend on May 3, 2026

Conversation

@jinzhen-lin (Contributor) commented on Apr 28, 2026

This PR adds the Humming MXFP4 MoE backend.

Humming project: https://github.com/inclusionAI/humming/

In #34556, we added an initial integration of Humming. I am now working on integrating the Humming backends into the dense/MoE kernel oracles.

This backend supports running DeepSeek-V4.

Note for users: You must pass the --moe-backend humming flag to use this backend, as Humming is not currently a mandatory dependency for vLLM. I am working on releasing the Humming PyPI package. Until then, you can install it using:

pip install git+https://github.com/inclusionAI/humming.git
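
If you want to double-check that the optional dependency is visible in your environment before enabling the backend, a quick sanity check (assuming the package is importable as humming; this is just an illustration, not part of vLLM) is:

import importlib.util

if importlib.util.find_spec("humming") is None:
    raise RuntimeError(
        "Humming is not installed; run "
        "pip install git+https://github.com/inclusionAI/humming.git first."
    )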

Benchmark

DeepSeek-V4-Flash + H20 x 4

Service start command:

# Marlin W4A16 (Baseline) (Use command from vllm recipes)
vllm serve /home/admin/DeepSeek-V4-Flash/ \
  --served-model-name model \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --enable-expert-parallel \
  --data-parallel-size 4 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4


# Humming W4A16
vllm serve /home/admin/DeepSeek-V4-Flash/ \
  --served-model-name model \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --enable-expert-parallel \
  --data-parallel-size 4 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4 \
  --moe-backend humming


# Humming W4A8
VLLM_HUMMING_INPUT_QUANT_CONFIG='{"dtype": "float8e4m3"}' vllm serve /home/admin/DeepSeek-V4-Flash/ \
  --served-model-name model \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --enable-expert-parallel \
  --data-parallel-size 4 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4 \
  --moe-backend humming

Bench command:

# for prefill
vllm bench serve \
    --model model \
    --host 127.0.0.1 \
    --tokenizer /home/admin/DeepSeek-V4-Flash/ \
    --num-prompts 128 \
    --random-input-len 65536 \
    --random-output-len 1 \
    --trust-remote-code \
    --max-concurrency 32

# for decoding
vllm bench serve \
    --model model \
    --host 127.0.0.1 \
    --tokenizer /home/admin/DeepSeek-V4-Flash/ \
    --num-prompts 256 \
    --random-input-len 1 \
    --random-output-len 4096 \
    --trust-remote-code \
    --max-concurrency 64

Bench result (TPS):

|               |  Prefill | Decoding |
|---------------|---------:|---------:|
| Marlin W4A16  | 10532.12 |  1650.82 |
| Humming W4A16 | 15052.24 |  1727.73 |
| Humming W4A8  | 19404.33 |  1730.07 |

The performance gains come primarily from improvements in the MoE kernel and the MoE sum kernel.
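
For reference, this is roughly the reduction that the MoE sum kernel accelerates (shapes and names here are illustrative, not the actual vLLM signatures); the baseline path performs it with torch.sum:

import torch

def moe_weighted_sum_reference(
    expert_outputs: torch.Tensor,   # [num_tokens, top_k, hidden_size]
    routing_weights: torch.Tensor,  # [num_tokens, top_k]
) -> torch.Tensor:
    # Scale each selected expert's output by its routing weight and reduce
    # over the top-k dimension; the Humming MoE sum kernel fuses this step.
    weighted = expert_outputs * routing_weights.unsqueeze(-1)
    return torch.sum(weighted, dim=1)  # [num_tokens, hidden_size]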

Kernel profiles (profiler screenshots):

- Marlin MoE W4A16 kernel
- Humming MoE W4A16 kernel
- Marlin MoE sum kernel (torch.sum)
- Humming MoE sum kernel (introduced in #34556)

Accuracy Test

lm_eval \
    --model local-chat-completions \
    --tasks gsm8k \
    --num_fewshot 2 \
    --batch_size auto \
    --model_args "model=model,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=128"

Marlin W4A16 (main)

Requesting API: 100%|██████████| 1319/1319 [03:23<00:00,  6.47it/s]
2026-04-28:21:01:31 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-chat-completions ({'model': 'model', 'base_url': 'http://localhost:8000/v1/chat/completions', 'num_concurrent': 128}), gen_kwargs: ({}), limit: None, num_fewshot: 2, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     2|exact_match|↑  |0.9272|±  |0.0072|
|     |       |strict-match    |     2|exact_match|↑  |0.1486|±  |0.0098|

Humming W4A16 (PR)

Requesting API: 100%|██████████| 1319/1319 [02:20<00:00,  9.40it/s]
2026-04-28:21:07:06 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-chat-completions ({'model': 'model', 'base_url': 'http://localhost:8000/v1/chat/completions', 'num_concurrent': 128}), gen_kwargs: ({}), limit: None, num_fewshot: 2, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     2|exact_match|↑  |0.9227|±  |0.0074|
|     |       |strict-match    |     2|exact_match|↑  |0.1471|±  |0.0098|

Humming W4A8 (PR)

Requesting API: 100%|██████████| 1319/1319 [02:21<00:00,  9.33it/s]
2026-04-28:21:12:52 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-chat-completions ({'model': 'model', 'base_url': 'http://localhost:8000/v1/chat/completions', 'num_concurrent': 128}), gen_kwargs: ({}), limit: None, num_fewshot: 2, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     2|exact_match|↑  |0.9219|±  |0.0074|
|     |       |strict-match    |     2|exact_match|↑  |0.1433|±  |0.0097|

@claude (Bot) left a comment:

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist (Bot) left a comment:

Code Review

This pull request integrates the Humming mixed-precision kernels into the vLLM framework, specifically focusing on Mixture of Experts (MoE) layers and MXFP4 quantization support. Key changes include the addition of the 'humming' kernel type, refactoring expert classes to use standardized configuration objects, and the introduction of a new utility module, humming_utils.py, for layer preparation and quantization configuration. The review feedback identifies several critical issues in the new code: an unnecessary self parameter in the standalone function humming_is_layer_skipped and the static method humming_gemm_type which would lead to runtime errors, as well as a logic error in the shape_config calculation for non-gated models.

@mergify (Bot) commented on Apr 28, 2026:

Hi @jinzhen-lin, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@huangzhilin-hzl (Contributor) commented:

@jinzhen-lin Nice job. Could you add the corresponding accuracy comparison too?

@jinzhen-lin (Contributor, Author) replied:

Hi @huangzhilin-hzl, thank you for your interest. The accuracy tests are done now.

@mgoin added the performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed), deepseek (Related to DeepSeek models), and nvidia labels on Apr 30, 2026
@mgoin (Member) left a comment:

Excellent work! Validated locally on H100 DSV4 with the latest humming

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA May 3, 2026
@vllm-bot vllm-bot merged commit 08834cc into vllm-project:main May 3, 2026
70 of 77 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 3, 2026
joa-stdn pushed a commit to joa-stdn/vllm that referenced this pull request May 4, 2026

Review comment thread on the diff:

@@ -60,49 +59,44 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
    def __init__(
        self,
        layer: torch.nn.Module,
Collaborator:

We should try to avoid passing the layer here if at all possible. It contains the modular kernels. If we ever construct the modular kernels at __init__ time of the layer (which we are considering) then this will lead to all sorts of problems.

@jinzhen-lin (Contributor, Author) replied:

Since Humming supports a wide variety of quantization combinations, the corresponding weight combinations are also quite numerous. To reduce complexity on the caller side, I prefer the layer-based approach. If directly passing the FusedMoE layer would cause issues, do you think it would be a good choice to extract all the required weights and reconstruct a temporary layer inside the modular kernels?
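
Purely as an illustration of what I mean (the attribute names below are placeholders, not the exact FusedMoE parameter names):

from types import SimpleNamespace

import torch

def extract_moe_tensors(layer: torch.nn.Module) -> SimpleNamespace:
    # Illustration only: collect references to just the tensors the kernel
    # needs into a lightweight stand-in object, so the modular kernel never
    # has to hold the real FusedMoE layer.
    return SimpleNamespace(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        w13_weight_scale=getattr(layer, "w13_weight_scale", None),
        w2_weight_scale=getattr(layer, "w2_weight_scale", None),
    )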

@jinzhen-lin (Contributor, Author) added:

I don't quite understand what "construct the modular kernels at __init__ time of the layer" means. Since the modular kernels currently require passing in a FusedMoEQuantConfig, and this config can only be fully defined after process_weights_after_loading, how are we supposed to construct the modular kernels at the __init__ stage? Do you plan to pass these in as runtime variables?

@bnellnm (Collaborator) replied on May 5, 2026:

> I don't quite understand what "construct the modular kernels at __init__ time of the layer" means. Since the modular kernels currently require passing in a FusedMoEQuantConfig, and this config can only be fully defined after process_weights_after_loading, how are we supposed to construct the modular kernels at the __init__ stage? Do you plan to pass these in as runtime variables?

Even though the modular kernels require a FusedMoEQuantConfig at construction time, they don't really need much information from it (if any). We've been discussing removing this as a requirement for construction so that modular kernels can be instantiated at the same time as the quant methods that own them. This is to address other subtle order of initialization issues related to the FusedMoE layer, quant methods, SharedExperts, MoERunner, etc.
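
Roughly, using simplified stand-in names (not the actual vLLM classes), the two orderings being compared look like this:

from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantConfigStub:
    # Stand-in for FusedMoEQuantConfig; today it is only fully known after
    # process_weights_after_loading has run.
    weight_dtype: str = "mxfp4"

class ModularKernelStub:
    # Stand-in for a modular kernel object.
    def __init__(self, quant_config: Optional[QuantConfigStub] = None):
        self.quant_config = quant_config

# Current ordering: the kernel is built inside process_weights_after_loading,
# once the quant config can be fully constructed.
def process_weights_after_loading(layer) -> None:
    layer.quant_config = QuantConfigStub()
    layer.kernel = ModularKernelStub(layer.quant_config)

# Considered ordering: the kernel is built when the quant method is created,
# without needing a quant config; the layer is still passed to apply() at
# runtime, never to the kernel's __init__.
class QuantMethodStub:
    def __init__(self) -> None:
        self.kernel = ModularKernelStub()  # no layer, no quant config here

    def apply(self, layer, hidden_states):
        # weights/scales/config are read from the layer only at call time
        return hidden_states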

@jinzhen-lin (Contributor, Author) asked:

So, are you planning to pass model parameters or layers as arguments to the apply function? (Many quantization methods have additional parameters besides weight and scale.) I can do the relevant refactoring work for humming in advance.

Collaborator:

No, the layer will still be passed as a runtime arg to apply. It's only a problem when the layer is used as an argument to __init__ of any modular kernel objects.

chaojun-zhang pushed a commit to chaojun-zhang/vllm that referenced this pull request on May 6, 2026.
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request on May 7, 2026.
