
Conversation

@luccafong
Collaborator

@luccafong luccafong commented Nov 12, 2025

Purpose

Models with int4 weights, such as Kimi-K2-Thinking, currently use CompressedTensorsWNA16MarlinMoEMethod, which returns None from get_fused_moe_quant_config and has no GEMM implementation selection. They therefore cannot go through the prepare/finalize path with DP/EP; only the naive all2all backend can be used in DP/EP mode, which is slow.

This PR adds the missing quant config and GEMM selection so that other all2all backends, such as DeepEP, can be used.
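
Concretely, the change boils down to two new methods on CompressedTensorsWNA16MarlinMoEMethod, sketched below. This is condensed from the diff quoted in the review thread further down and relies on vLLM-internal helpers, so it is not a standalone snippet; the non-batched branch is truncated in the quoted diff, so its exact constructor call here is an assumption.

```python
def get_fused_moe_quant_config(self, layer):
    # Describe the int4-weight / 16-bit-activation scheme so the modular
    # fused-MoE kernel path knows how the expert weights are quantized.
    if self.num_bits != 4:
        return None
    return int4_w4a16_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        w1_zp=None,
        w2_zp=None,
        block_shape=[0, self.group_size],
    )

def select_gemm_impl(self, prepare_finalize, layer):
    # Pick a Marlin expert implementation matching the activation format
    # produced by the selected all2all backend (e.g. DeepEP).
    if (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    ):
        return BatchedMarlinExperts(
            max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
            num_dispatchers=prepare_finalize.num_dispatchers(),
            quant_config=self.moe_quant_config,
        )
    # Non-batched branch is truncated in the quoted diff; assumed to
    # construct MarlinExperts similarly.
    return MarlinExperts(quant_config=self.moe_quant_config)
```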

Test Plan

VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve /data/local/models/oss/Kimi-K2-Thinking -dp 8  --max-model-len 16384 --max-num-seqs 32 --block-size 64 --trust-remote-code --enable-expert-parallel 
VLLM_ALL2ALL_BACKEND=deepep_high_throughput vllm serve /data/local/models/oss/Kimi-K2-Thinking -dp 8  --max-model-len 16384 --max-num-seqs 32 --block-size 64 --trust-remote-code --enable-expert-parallel 

Test Result

lm_eval (gsm8k) results are on par with the main branch using the naive a2a backend:

VLLM_ALL2ALL_BACKEND=naive vllm serve /data/local/models/oss/Kimi-K2-Thinking -dp 8 --max-model-len 32768 --max-num-seqs 32 --block-size 64 --trust-remote-code --enable-expert-parallel

Baseline (main, naive a2a backend)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9416|±  |0.0065|
|     |       |strict-match    |     5|exact_match|↑  |0.9416|±  |0.0065|

deepep_high_throughput (this PR)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9454|±  |0.0063|
|     |       |strict-match    |     5|exact_match|↑  |0.9447|±  |0.0063|

deepep_low_latency (this PR)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9424|±  |0.0064|
|     |       |strict-match    |     5|exact_match|↑  |0.9409|±  |0.0065|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify

mergify bot commented Nov 12, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 12, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to add support for DeepEP for Kimi2-thinking models by enabling gemm selection for CompressedTensorsWNA16MarlinMoEMethod. The changes correctly implement get_fused_moe_quant_config and select_gemm_impl to handle int4 weights and select the appropriate Marlin expert implementation. The supporting changes in other files are also correct. However, I've found a critical issue in the implementation of select_gemm_impl which will cause a runtime error. Please see the detailed comment below.

@luccafong luccafong changed the title Support DeepEP for Kimi2-thinking through enabling gemm selection for compressed-tensor marlin wna16 Support DeepEP for Kimi-k2-thinking through enabling gemm selection for compressed-tensor marlin wna16 Nov 12, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

@luccafong
Collaborator Author

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

@mgoin
Member

mgoin commented Nov 12, 2025

cc @varun-sundar-rabindranath since you looked into mxfp4 for gpt-oss

Signed-off-by: Lu Fang <[email protected]>

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment on lines 1562 to 1595
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    if self.num_bits != 4:
        return None
    return int4_w4a16_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        w1_zp=None,
        w2_zp=None,
        block_shape=[0, self.group_size],
    )

def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    layer.w13_weight = layer.w13_weight_packed
    layer.w2_weight = layer.w2_weight_packed
    assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
    assert self.moe_quant_config is not None
    if (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    ):
        return BatchedMarlinExperts(
            max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
            num_dispatchers=prepare_finalize.num_dispatchers(),
            quant_config=self.moe_quant_config,

P1: Lose act-order indices when routing through modular Marlin experts

The new select_gemm_impl now returns BatchedMarlinExperts/MarlinExperts for CompressedTensorsWNA16MarlinMoEMethod, but those experts never forward the g_idx* and sort_indices* tensors that are passed to fused_marlin_moe in the non-modular path (apply still calls fused_marlin_moe(..., g_idx1=..., g_idx2=..., sort_indices1=..., sort_indices2=...)). For models quantized with grouped activation ordering (which populate these tensors during process_weights_after_loading), the modular kernel used for DP/EP will silently drop the act-order permutation information, causing incorrect expert outputs when the DeepEP/prepare-finalize path is enabled. Consider wiring the g‑index tensors through the modular Marlin experts or gating the modular path off for act-ordered weights.


Contributor


@luccafong - This call-out seems reasonable. Looks like you'd need to plumb through g_idx1, g_idx2, sort_indices1, and sort_indices2? Can you please take a look? Thanks.

Collaborator Author

@luccafong luccafong Nov 12, 2025


Hmm, it seems we would either need to touch the base method signature of FusedMoEPrepareAndFinalize to add them, or initialize them in MarlinExperts.

Collaborator Author


Let me try the latter approach.
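
A minimal, self-contained sketch of that latter approach, using hypothetical names (SketchMarlinExperts and a stub standing in for fused_marlin_moe; the actual fix landed in 1e18fc8): the experts object captures the act-order tensors at construction time and forwards them whenever it invokes the fused Marlin kernel, so nothing in the FusedMoEPrepareAndFinalize interface has to change.

```python
import torch


def _fused_marlin_moe_stub(hidden_states, **kwargs):
    # Placeholder standing in for the real fused_marlin_moe kernel call,
    # so this sketch runs on its own.
    return hidden_states


class SketchMarlinExperts:
    """Illustrative only: capture act-order tensors when the experts are built."""

    def __init__(self, quant_config, g_idx1=None, g_idx2=None,
                 sort_indices1=None, sort_indices2=None):
        self.quant_config = quant_config
        # Act-order permutation info, populated from the tensors created
        # during weight post-processing (None for non-act-order models).
        self.g_idx1 = g_idx1
        self.g_idx2 = g_idx2
        self.sort_indices1 = sort_indices1
        self.sort_indices2 = sort_indices2

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Forward the captured tensors so grouped act-order models keep
        # their permutation on the DP/EP (prepare/finalize) path.
        return _fused_marlin_moe_stub(
            hidden_states,
            quant_config=self.quant_config,
            g_idx1=self.g_idx1,
            g_idx2=self.g_idx2,
            sort_indices1=self.sort_indices1,
            sort_indices2=self.sort_indices2,
        )
```

With this shape, select_gemm_impl can hand the g_idx/sort_indices tensors produced during process_weights_after_loading to the experts object it constructs.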

Collaborator Author


resolved in 1e18fc8

@varun-sundar-rabindranath
Contributor

@luccafong - along with

 VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve /data/local/models/oss/Kimi-K2-Thinking -dp 8 --kv-cache-dtype fp8 --max-model-len 16384 --max-num-seqs 32 --block-size 64 --trust-remote-code --enable-expert-parallel 

Can you also try running deepep_high_throughput and provide an lm_eval comparison against main? Thanks.

Signed-off-by: Lu Fang <[email protected]>
@luccafong
Collaborator Author

> @luccafong - along with
>
> VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve /data/local/models/oss/Kimi-K2-Thinking -dp 8 --kv-cache-dtype fp8 --max-model-len 16384 --max-num-seqs 32 --block-size 64 --trust-remote-code --enable-expert-parallel
>
> Can you also try running deepep_high_throughput and provide an lm_eval comparison against main? Thanks.

updated with test plan and results

Signed-off-by: Lu Fang <[email protected]>
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment


LGTM! Thanks @luccafong

@luccafong
Collaborator Author

@mgoin @houseroad could you review to get committer approval? thanks!

Signed-off-by: Lu Fang <[email protected]>
@mgoin mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 13, 2025
Member

@mgoin mgoin left a comment


LGTM, nice!

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Nov 13, 2025
@luccafong luccafong merged commit 7e082bc into vllm-project:main Nov 13, 2025
56 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Nov 13, 2025
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
…or compressed-tensor marlin wna16 (vllm-project#28574)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: George D. Torres <[email protected]>
bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025
…or compressed-tensor marlin wna16 (vllm-project#28574)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Bram Wasti <[email protected]>
@luccafong luccafong deleted the kimi_moe_marlin branch November 18, 2025 23:59