
[Perf] DeepGEMM fused layout kernel for activations, 4.3% throughput improvement, 10.7% TTFT improvement. #29546

Merged
youkaichao merged 6 commits into main from wentao-deepgemm-fused-layout-kernel on Dec 7, 2025

Conversation

Member

@yewentao256 yewentao256 commented Nov 26, 2025

Purpose

Fuse the DeepGEMM layout transform into the per-token-group quantization kernel for better performance.

Namely, pack the scales into a uint32 earlier and remove an additional kernel launch.
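For intuition, here is a minimal pure-Python sketch of the packing idea (the actual change is a CUDA kernel; `pack_ue8m0` and the 4-exponents-per-word grouping below are illustrative, not the kernel's exact layout):

```python
import math

def pack_ue8m0(scales):
    """Round four positive scales up to powers of two, keep each 8-bit
    biased exponent (UE8M0), and pack them into one 32-bit word."""
    assert len(scales) == 4
    word = 0
    for i, s in enumerate(scales):
        # ceil(log2(s)) is the smallest power-of-two exponent >= s
        exponent = min(max(math.ceil(math.log2(s)) + 127, 0), 255)
        word |= exponent << (8 * i)
    return word  # fits in a uint32

print(hex(pack_ue8m0([1.0, 2.0, 0.5, 4.0])))  # 0x817e807f
```

Doing this packing inside the quantization kernel is what removes the separate layout-transform launch.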

Test

vllm serve deepseek-ai/DeepSeek-V3.1 -tp 8 --enable-expert-parallel --port 9256 --enforce_eager

Acc

lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=deepseek-ai/DeepSeek-V3.1,num_concurrent=1024" --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9568|±  |0.0056|

Perf

vllm bench serve --model deepseek-ai/DeepSeek-V3.1 --dataset-name random --host 127.0.0.1 --port 9256 --random-input-len 2 --random-output-len 256 --request-rate inf --num-prompts 1024

Now
============ Serving Benchmark Result ============
Successful requests:                     1024      
Failed requests:                         0         
Benchmark duration (s):                  27.94     
Total input tokens:                      3072      
Total generated tokens:                  262144    
Request throughput (req/s):              36.65     
Output token throughput (tok/s):         9382.68   
Peak output token throughput (tok/s):    10240.00  
Peak concurrent requests:                1024.00   
Total Token throughput (tok/s):          9492.63   
---------------Time to First Token----------------
Mean TTFT (ms):                          1013.22   
Median TTFT (ms):                        982.19    
P99 TTFT (ms):                           1157.11   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          105.02    
Median TPOT (ms):                        105.08    
P99 TPOT (ms):                           105.26    
---------------Inter-token Latency----------------
Mean ITL (ms):                           105.02    
Median ITL (ms):                         104.57    
P99 ITL (ms):                            131.10    
==================================================

Main
============ Serving Benchmark Result ============
Successful requests:                     1024      
Failed requests:                         0         
Benchmark duration (s):                  29.15     
Total input tokens:                      3072      
Total generated tokens:                  262144    
Request throughput (req/s):              35.13     
Output token throughput (tok/s):         8994.31   
Peak output token throughput (tok/s):    10132.00  
Peak concurrent requests:                1024.00   
Total Token throughput (tok/s):          9099.71   
---------------Time to First Token----------------
Mean TTFT (ms):                          1121.89   
Median TTFT (ms):                        1144.56   
P99 TTFT (ms):                           1241.25   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          109.30    
Median TPOT (ms):                        109.29    
P99 TPOT (ms):                           109.40    
---------------Inter-token Latency----------------
Mean ITL (ms):                           109.30    
Median ITL (ms):                         108.89    
P99 ITL (ms):                            125.29    
==================================================
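For reference, the headline numbers can be recomputed from the two tables above; the throughput gain matches the title's 4.3%, while the mean-TTFT delta from these particular tables comes out near 9.7% (the 10.7% figure presumably comes from a different run or statistic):

```python
# Relative improvements derived from the "Now" and "Main" tables above.
now = {"req_s": 36.65, "mean_ttft_ms": 1013.22, "mean_tpot_ms": 105.02}
main = {"req_s": 35.13, "mean_ttft_ms": 1121.89, "mean_tpot_ms": 109.30}

tput_gain = (now["req_s"] - main["req_s"]) / main["req_s"]
ttft_gain = (main["mean_ttft_ms"] - now["mean_ttft_ms"]) / main["mean_ttft_ms"]
tpot_gain = (main["mean_tpot_ms"] - now["mean_tpot_ms"]) / main["mean_tpot_ms"]

print(f"throughput +{tput_gain:.1%}, TTFT -{ttft_gain:.1%}, TPOT -{tpot_gain:.1%}")
```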

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused CUDA kernel for activation quantization and scale packing, targeting performance improvements with DeepGEMM. The changes are well-motivated and backed by performance data showing significant gains. My review focuses on the correctness and maintainability of the new CUDA kernel and its Python integration. I've identified two high-severity issues: one related to obscure and fragile bit-packing logic in the CUDA kernel that should be refactored for clarity and robustness, and another in the Python wrapper which fails to use a pre-allocated output buffer, leading to unnecessary memory allocations.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

torch.ops.vllm.fp8_gemm_nt_op(
    q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
)

P1: DeepGEMM path packs UE8M0 scales but signals E8M0 off

The DeepGEMM linear path now quantizes activations with use_ue8m0=True, producing UE8M0-packed int32 scales, but the subsequent fp8_gemm_nt_op call still forwards self.use_deep_gemm_e8m0. When VLLM_USE_DEEP_GEMM_E8M0 is false (the default on supported GPUs), this flag is false, so DeepGEMM will interpret the scale buffer as its non-E8M0 float format while it actually contains packed exponents, leading to incorrect matmul results whenever DeepGEMM is used without E8M0 enabled.
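The failure mode Codex describes can be demonstrated with a few lines of plain Python: if a word holding four UE8M0 exponents is reinterpreted as a float32 scale, the result is garbage. A stdlib-only sketch with illustrative numbers:

```python
import struct

# Four per-group scales 1.0, 2.0, 0.5, 4.0 stored as UE8M0 biased exponents.
exponents = [127, 128, 126, 129]
packed = sum(e << (8 * i) for i, e in enumerate(exponents))

# Reinterpreting the packed word as a float32 scale, as a consumer would if
# told the buffer is not E8M0-packed, yields a meaningless tiny value.
misread = struct.unpack("<f", struct.pack("<I", packed))[0]
print(misread)  # nothing like the original scales
```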

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@yewentao256 yewentao256 marked this pull request as draft November 26, 2025 21:52
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 marked this pull request as ready for review November 26, 2025 22:19
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 26, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@yewentao256
Member Author

@youkaichao CC

@mergify

mergify bot commented Dec 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yewentao256.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 4, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@mergify mergify bot removed the needs-rebase label Dec 4, 2025
Member

@youkaichao youkaichao left a comment


thanks for the great work!

for context, per offline discussion with @LyricZhao , deepgemm requires this packed scaling factor for optimal performance.

@youkaichao youkaichao merged commit 541a2ef into main Dec 7, 2025
95 of 96 checks passed
@youkaichao youkaichao deleted the wentao-deepgemm-fused-layout-kernel branch December 7, 2025 12:31
@zifeitong
Contributor

zifeitong commented Dec 8, 2025

DeepSeek 3.2 is broken on HEAD, and it looks like it's related to this PR:

(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750] WorkerProc failed to start.
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750] Traceback (most recent call last):
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/executor/multiproc_executor.py", line 722, in worker_main
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     worker = WorkerProc(*args, **kwargs)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/executor/multiproc_executor.py", line 562, in __init__
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.worker.load_model()
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/worker/gpu_worker.py", line 278, in load_model
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/worker/gpu_model_runner.py", line 3561, in load_model
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.model = model_loader.load_model(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]                  ~~~~~~~~~~~~~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         vllm_config=self.vllm_config, model_config=self.model_config
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/model_loader/base_loader.py", line 56, in load_model
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     process_weights_after_loading(model, model_config, target_device)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/model_loader/utils.py", line 118, in process_weights_after_loading
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     module.process_weights_after_loading(model_config.dtype)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/attention/layer.py", line 730, in process_weights_after_loading
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.impl.process_weights_after_loading(act_dtype)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/attention/backends/mla/common.py", line 1179, in process_weights_after_loading
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]                        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/attention/backends/mla/common.py", line 1170, in get_and_maybe_dequant_weights
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 634, in apply
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     return self.w8a8_block_fp8_linear.apply(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         input=x,
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ...<3 lines>...
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         bias=bias,
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 255, in apply
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     output = self._run_deepgemm(input_2d, weight, weight_scale)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 282, in _run_deepgemm
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     torch.ops.vllm.fp8_gemm_nt_op(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/torch/_ops.py", line 1255, in __call__
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     return self._op(*args, **kwargs)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]            ~~~~~~~~^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 170, in _fp8_gemm_nt_op
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     fp8_gemm_nt(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         (q_input, input_scale),
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ...<2 lines>...
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/utils/deep_gemm.py", line 186, in fp8_gemm_nt
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750] RuntimeError: Assertion error (csrc/apis/../jit_kernels/impls/../heuristics/../../utils/layout.hpp:49): sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat
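The assertion at the end of the traceback (`sfa_dtype == torch::kFloat`) is DeepGEMM's layout check: with the E8M0 path off it expects float32 scale factors, while the fused kernel now hands it int32-packed UE8M0 exponents. A minimal sketch of that contract (dtypes as plain strings; `check_scale_dtype` is illustrative, not vLLM's API):

```python
# Illustrative only: the GEMM expects float32 scale factors unless UE8M0
# packing is signalled; packed scales arrive as 32-bit integer words.
def check_scale_dtype(scale_dtype: str, use_ue8m0: bool) -> None:
    expected = "int32" if use_ue8m0 else "float32"
    if scale_dtype != expected:
        raise TypeError(
            f"scale dtype {scale_dtype} does not match use_ue8m0={use_ue8m0}; "
            f"expected {expected}"
        )

check_scale_dtype("float32", use_ue8m0=False)  # the unpacked contract
try:
    # What the traceback shows: packed int32 scales with E8M0 signalled off.
    check_scale_dtype("int32", use_ue8m0=False)
except TypeError as err:
    print(err)
```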

@SongDI911
Contributor

Deepseek 3.2 is broken on HEAD, and it looks that it's related to this PR:

(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750] WorkerProc failed to start.
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750] Traceback (most recent call last):
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/executor/multiproc_executor.py", line 722, in worker_main
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     worker = WorkerProc(*args, **kwargs)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/executor/multiproc_executor.py", line 562, in __init__
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.worker.load_model()
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/worker/gpu_worker.py", line 278, in load_model
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/worker/gpu_model_runner.py", line 3561, in load_model
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.model = model_loader.load_model(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]                  ~~~~~~~~~~~~~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         vllm_config=self.vllm_config, model_config=self.model_config
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/model_loader/base_loader.py", line 56, in load_model
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     process_weights_after_loading(model, model_config, target_device)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/model_loader/utils.py", line 118, in process_weights_after_loading
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     module.process_weights_after_loading(model_config.dtype)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/attention/layer.py", line 730, in process_weights_after_loading
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     self.impl.process_weights_after_loading(act_dtype)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/attention/backends/mla/common.py", line 1179, in process_weights_after_loading
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]                        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/v1/attention/backends/mla/common.py", line 1170, in get_and_maybe_dequant_weights
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 634, in apply
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     return self.w8a8_block_fp8_linear.apply(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         input=x,
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ...<3 lines>...
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         bias=bias,
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 255, in apply
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     output = self._run_deepgemm(input_2d, weight, weight_scale)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 282, in _run_deepgemm
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     torch.ops.vllm.fp8_gemm_nt_op(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/torch/_ops.py", line 1255, in __call__
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     return self._op(*args, **kwargs)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]            ~~~~~~~~^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 170, in _fp8_gemm_nt_op
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     fp8_gemm_nt(
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ~~~~~~~~~~~^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         (q_input, input_scale),
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ...<2 lines>...
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     )
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     ^
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]   File "/opt/venv/lib/python3.13/site-packages/vllm/utils/deep_gemm.py", line 186, in fp8_gemm_nt
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750]     return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
(Worker_TP2 pid=280) ERROR 12-08 03:07:48 [multiproc_executor.py:750] RuntimeError: Assertion error (csrc/apis/../jit_kernels/impls/../heuristics/../../utils/layout.hpp:49): sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat

I also encountered this when running DeepSeek 3.2 on H200.

yeqcharlotte pushed a commit that referenced this pull request Dec 9, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Zhewen Li <zhewenli@meta.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…improvement, 10.7% TTFT improvement. (vllm-project#29546)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Zhewen Li <zhewenli@meta.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)

4 participants