[ROCM] DSfp4 mla projection gemms weight dynamic quantization #32238
gshtras merged 6 commits into vllm-project:main
Conversation
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Code Review
This pull request introduces support for DSfp4 MLA projection GEMMs with dynamic weight quantization on ROCm. The changes look good overall, adding new environment variables and logic to handle FP4 batched matrix multiplication. I've found a critical issue and a high-severity typo that should be addressed.
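For readers unfamiliar with the technique, the sketch below shows the general shape of dynamic per-block FP4 (e2m1) weight quantization that a path like this relies on. It is a plain-PyTorch fake-quantization reference, not this PR's AITER Triton kernel; the block size of 32, the helper name, and the scale heuristic are assumptions.

```python
import torch

# Positive magnitudes representable in FP4 e2m1 (sign handled separately).
FP4_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quant_fp4_blockwise(w: torch.Tensor, block: int = 32):
    """Dynamically quantize a 2-D weight to the e2m1 grid, one scale per block.

    Returns the dequantized (fake-quantized) weight and the per-block scales;
    a real kernel would keep packed 4-bit codes and the scales instead.
    """
    rows, cols = w.shape
    assert cols % block == 0, "illustrative sketch assumes divisible blocks"
    w_blocks = w.reshape(rows, cols // block, block)

    # Per-block scale chosen so the largest magnitude maps to the grid max.
    scale = w_blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    scaled = w_blocks / scale.clamp(min=1e-12)

    # Snap each value to the nearest representable e2m1 magnitude.
    # (Materializes an 8x-wide distance tensor; illustration only.)
    dist = (scaled.abs().unsqueeze(-1) - FP4_E2M1_GRID).abs()
    q = FP4_E2M1_GRID[dist.argmin(dim=-1)] * scaled.sign()

    return (q * scale).reshape(rows, cols), scale.squeeze(-1)
```

Because the scales are computed from the weights at load time rather than calibrated offline, this is "dynamic" quantization in the sense the PR title uses.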
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
```diff
@@ -2045,7 +2074,7 @@ def forward(
             scale=layer._k_scale,
         )

-        if fp8_attention:
+        if fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled:
```
FP8 KV cache view skipped when FP4BMM enabled
High Severity
The condition `fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled` can cause data corruption when both the FP8 KV cache and FP4BMM are enabled simultaneously. The `concat_and_cache_mla` call at line 2068 writes FP8-encoded data when `kv_cache_dtype` is `"fp8"`, but line 2077 skips the `.view(fp8_dtype)` conversion when `is_aiter_triton_fp4_bmm_enabled` is True. Since these conditions are independent (one depends on the KV cache dtype, the other on the weight dtype and a feature flag), both can be True, causing subsequent operations to read FP8 bytes as the original dtype and produce garbage results.
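To make the finding concrete, here is one possible shape of the corrected guard; the names follow the diff above, and this is a sketch, not the merged code.

```python
# Sketch of a possible fix (names follow the diff; not the merged change).
# The cache view must track the KV-cache dtype alone: if concat_and_cache_mla
# wrote FP8 bytes, reinterpret the cache as FP8 here regardless of whether
# the FP4 BMM weight path is enabled, and gate only the weight-side BMM
# logic on the FP4 flag.
if fp8_attention:
    kv_cache = kv_cache.view(fp8_dtype)
```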
```python
            transpose_bm=True,
            prequant=True,
            y_scale=None,
        )
```
Base class missing FP4BMM attribute setup in process_weights_after_loading
Medium Severity
`MLACommonBaseImpl._v_up_proj` was modified to use `self.W_V` and `self.W_V_scale` when `is_aiter_triton_fp4_bmm_enabled` is True, but `MLACommonBaseImpl.process_weights_after_loading` only sets these attributes for FP8BMM (in the `if self.is_aiter_triton_fp8_bmm_enabled` branch). When FP4BMM is enabled but FP8BMM is disabled, `process_weights_after_loading` falls into the `else` branch, which sets `W_UV` and `W_UK_T` instead, leaving `W_V` and `W_V_scale` undefined. This leaves the base class in an inconsistent state: while `MLACommonImpl` handles FP4BMM setup correctly, any direct subclass of `MLACommonBaseImpl` would hit an `AttributeError` at runtime.
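A minimal structural sketch of the missing branch, assuming the flag and attribute names from the finding; `_dynamic_quant_stub` is a stand-in for the real AITER FP8/FP4 quantization helpers, and quantizing from `self.W_UV` is an assumption.

```python
import torch

def _dynamic_quant_stub(w: torch.Tensor):
    """Stand-in for the real AITER dynamic quantizer (returns weight, scale)."""
    scale = w.abs().amax().clamp(min=1e-12)
    return (w / scale), scale

class MLACommonBaseImplSketch:
    """Structure only; mirrors the branches described in the finding."""

    def process_weights_after_loading(self) -> None:
        if self.is_aiter_triton_fp8_bmm_enabled:
            # Existing branch: populates W_V / W_V_scale via the FP8 path.
            self.W_V, self.W_V_scale = _dynamic_quant_stub(self.W_UV)
        elif self.is_aiter_triton_fp4_bmm_enabled:
            # Missing branch per the finding: the FP4 path must also set
            # W_V and W_V_scale, since _v_up_proj reads them when this
            # flag is enabled.
            self.W_V, self.W_V_scale = _dynamic_quant_stub(self.W_UV)
        else:
            # Unquantized fallback: W_UV / W_UK_T remain as loaded.
            pass
```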
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
…roject#32238)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

…roject#32238)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Commands:
Server:
Client:
Output:
Legacy path (both flags off) also works: `VLLM_ROCM_USE_AITER_FP4BMM=0 VLLM_ROCM_USE_AITER_FP8BMM=0`
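For completeness, the flags above must be set in the environment before vLLM starts; a minimal Python sketch (the env var names come from this PR and the line above, everything else is illustrative):

```python
import os

# Opt in to the AITER Triton FP4 batched-GEMM path for the MLA projections
# (must be set before vLLM reads its environment configuration).
os.environ["VLLM_ROCM_USE_AITER_FP4BMM"] = "1"

# Legacy behavior, as noted above, disables both quantized BMM paths:
# os.environ["VLLM_ROCM_USE_AITER_FP4BMM"] = "0"
# os.environ["VLLM_ROCM_USE_AITER_FP8BMM"] = "0"
```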