
[ROCM] DSfp4 mla projection gemms weight dynamic quantization #32238

Merged
gshtras merged 6 commits into vllm-project:main from ROCm:dsfp4_gemms_part1
Jan 15, 2026

Conversation

@maleksan85 (Contributor) commented Jan 13, 2026

Commands:

Server:

VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_FP4BMM=1 VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER_MLA=1 AMDGCN_USE_BUFFER_OPS=1 SAFETENSORS_FAST_GPU=1 vllm serve /data/models/DeepSeek-R1-0528-MXFP4-Preview --host localhost --port 8000 --tensor-parallel-size 8 --max-num-batched-tokens 32768 --trust-remote-code --no-enable-prefix-caching --disable-log-requests --gpu_memory_utilization 0.8 --async-scheduling --block-size 16 --load-format fastsafetensors --seed 123 --enforce-eager

Client:

curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
"model": "/data/models/DeepSeek-R1-0528-MXFP4-Preview",
"prompt": "San Francisco is a",
"max_tokens": 16,
"temperature": 0
}'

Output:

{"id":"cmpl-9afa69e020697740","object":"text_completion","created":1768286231,"model":"/data/models/DeepSeek-R1-0528-MXFP4-Preview","choices":[{"index":0,"text":" city known for its vibrant culture, stunning architecture, and diverse neighborhoods. Among its","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":5,"total_tokens":21,"completion_tokens":16,"prompt_tokens_details":null},"kv_transfer_params":null}

Legacy path: setting VLLM_ROCM_USE_AITER_FP4BMM=0 VLLM_ROCM_USE_AITER_FP8BMM=0 also works.
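The feature is gated by boolean environment flags like the ones in the command above. A minimal sketch of how such flags are typically parsed (this is illustrative, not vLLM's actual `envs` module; `env_flag` is a hypothetical helper):

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean feature flag."""
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "yes")

# Mirroring the legacy configuration mentioned above: both flags off.
os.environ["VLLM_ROCM_USE_AITER_FP4BMM"] = "0"
os.environ["VLLM_ROCM_USE_AITER_FP8BMM"] = "0"

use_fp4_bmm = env_flag("VLLM_ROCM_USE_AITER_FP4BMM")
use_fp8_bmm = env_flag("VLLM_ROCM_USE_AITER_FP8BMM")
print(use_fp4_bmm, use_fp8_bmm)  # both False -> legacy (non-AITER-BMM) path
```

With both flags false, the attention backend falls back to the pre-existing unquantized BMM path, which is why the legacy configuration continues to work alongside the new FP4 path.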

Aleksandr Malyshev added 3 commits January 13, 2026 05:20
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@mergify mergify Bot added the rocm (Related to AMD ROCm) and v1 labels Jan 13, 2026

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for DSfp4 MLA projection GEMMs with dynamic weight quantization on ROCm. The changes look good overall, adding new environment variables and logic to handle FP4 batched matrix multiplication. I've found a critical issue and a high-severity typo that should be addressed.

Comment thread vllm/model_executor/layers/attention/mla_attention.py Outdated
Comment thread vllm/model_executor/layers/quantization/quark/utils.py Outdated
Comment thread vllm/v1/attention/backends/mla/common.py Outdated

mergify Bot commented Jan 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maleksan85.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jan 13, 2026
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@maleksan85 maleksan85 marked this pull request as ready for review January 13, 2026 20:25
@@ -2045,7 +2074,7 @@ def forward(
scale=layer._k_scale,
)

if fp8_attention:
if fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled:

FP8 KV cache view skipped when FP4BMM enabled

High Severity

The condition fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled can cause data corruption when both FP8 KV cache and FP4BMM are enabled simultaneously. The concat_and_cache_mla call at line 2068 writes FP8-encoded data when kv_cache_dtype is "fp8", but line 2077 skips the .view(fp8_dtype) conversion when is_aiter_triton_fp4_bmm_enabled is True. Since these conditions are independent (one depends on KV cache dtype, the other on weight dtype and feature flag), both can be True, causing subsequent operations to read FP8 bytes as the original dtype, producing garbage results.

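To make the hazard concrete, here is a minimal self-contained sketch (not vLLM's actual code; `AttnConfig` and `needs_fp8_view` are hypothetical names) showing that the two conditions are independent, so the patched guard can skip the FP8 view even when FP8-encoded data was written to the cache:

```python
from dataclasses import dataclass

@dataclass
class AttnConfig:
    kv_cache_dtype: str    # drives fp8_attention ("fp8" vs. "auto")
    fp4_bmm_enabled: bool  # weight-side feature flag (FP4BMM)

def needs_fp8_view(cfg: AttnConfig) -> bool:
    """The patched condition: view the KV cache as fp8_dtype only when
    FP8 attention is on AND FP4BMM is off."""
    fp8_attention = cfg.kv_cache_dtype == "fp8"
    return fp8_attention and not cfg.fp4_bmm_enabled

# The hazardous combination flagged by the review: FP8 KV cache with
# FP4BMM enabled. The cache holds FP8 bytes, but no view is applied,
# so downstream reads misinterpret those bytes.
hazard = AttnConfig(kv_cache_dtype="fp8", fp4_bmm_enabled=True)
print(needs_fp8_view(hazard))  # False -> .view(fp8_dtype) is skipped
```

Because the KV-cache dtype and the FP4BMM flag are configured independently, all four combinations are reachable at runtime, which is why the review comment calls this a correctness issue rather than a dead branch.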

transpose_bm=True,
prequant=True,
y_scale=None,
)

Base class missing FP4BMM attribute setup in process_weights_after_loading

Medium Severity

MLACommonBaseImpl._v_up_proj was modified to use self.W_V and self.W_V_scale when is_aiter_triton_fp4_bmm_enabled is True, but MLACommonBaseImpl.process_weights_after_loading only sets these attributes for FP8BMM (in the if self.is_aiter_triton_fp8_bmm_enabled branch). When FP4BMM is enabled but FP8BMM is disabled, process_weights_after_loading falls into the else branch which sets W_UV and W_UK_T instead, leaving W_V and W_V_scale undefined. This causes the base class to be in an inconsistent state. While MLACommonImpl handles FP4BMM setup correctly, any direct subclass of MLACommonBaseImpl would hit an AttributeError at runtime.

Additional Locations (1)

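A stripped-down illustration of the inconsistent state the review describes (class names and attribute values here are placeholders, not vLLM's implementation): when FP4BMM is enabled but FP8BMM is not, `process_weights_after_loading` never defines `W_V`/`W_V_scale`, so `_v_up_proj` fails.

```python
class BaseImplSketch:
    """Placeholder for MLACommonBaseImpl's branch structure."""
    def __init__(self, fp8_bmm: bool, fp4_bmm: bool):
        self.is_aiter_triton_fp8_bmm_enabled = fp8_bmm
        self.is_aiter_triton_fp4_bmm_enabled = fp4_bmm

    def process_weights_after_loading(self):
        if self.is_aiter_triton_fp8_bmm_enabled:
            # Only the FP8BMM branch sets W_V / W_V_scale.
            self.W_V, self.W_V_scale = "W_V", "W_V_scale"
        else:
            # FP4BMM-only configs fall through here: W_UV / W_UK_T are
            # set, but W_V / W_V_scale are never defined.
            self.W_UV, self.W_UK_T = "W_UV", "W_UK_T"

    def _v_up_proj(self):
        if self.is_aiter_triton_fp4_bmm_enabled:
            return self.W_V, self.W_V_scale  # AttributeError if FP8BMM off
        return self.W_UV

impl = BaseImplSketch(fp8_bmm=False, fp4_bmm=True)
impl.process_weights_after_loading()
try:
    impl._v_up_proj()
except AttributeError as exc:
    print("inconsistent base-class state:", exc)
```

As the comment notes, `MLACommonImpl` overrides the setup and avoids this, so the failure would only surface in a direct subclass of the base class, but the base class's own invariants are still violated.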

@mergify mergify Bot removed the needs-rebase label Jan 13, 2026
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@gshtras gshtras added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jan 15, 2026
@gshtras gshtras merged commit 8c11001 into vllm-project:main Jan 15, 2026
58 of 59 checks passed
@gshtras gshtras deleted the dsfp4_gemms_part1 branch January 15, 2026 20:13
gshtras pushed a commit to ROCm/vllm that referenced this pull request Jan 15, 2026
…roject#32238)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
sammysun0711 pushed a commit to sammysun0711/vllm that referenced this pull request Jan 16, 2026
…roject#32238)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…roject#32238)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
danichan-mkm pushed a commit to danichan-mkm/vllm that referenced this pull request Feb 11, 2026
…roject#32238)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1


2 participants