
[mxfp4] Remove unnecessary process_weights_after_loading handling in case simulation is used#26111

Closed
fxmarty-amd wants to merge 1 commit into vllm-project:main from fxmarty-amd:remove-unnecessary-mxfp4-simulation

Conversation

@fxmarty-amd
Contributor

@fxmarty-amd fxmarty-amd commented Oct 2, 2025

As per title.

#25135 added some code in the simulation (QDQ) case for process_weights_after_loading, which is not necessary, and which was not really the purpose of the PR.

See my comments #25135 (comment) and #25135 (comment)

cc @maleksan85

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
@fxmarty-amd fxmarty-amd force-pushed the remove-unnecessary-mxfp4-simulation branch from 0b64348 to d6c0f13 on October 2, 2025 16:54
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly removes unnecessary and buggy weight processing logic from process_weights_after_loading for the mxfp4 emulation (simulation) mode. The removed code was dequantizing the weights prematurely during the model loading phase. This conflicted with the apply_weights method, which is designed to perform dequantization at runtime during the forward pass. By removing this block, the bug is fixed, and weights are correctly handled, remaining in their quantized format until they are needed for computation. The change is clean, well-justified, and improves the correctness of the implementation. No issues were found in this change.
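The QDQ (quantize-dequantize) emulation path described above keeps the packed weights resident and only materializes values on the fp4 grid during the forward pass. As a minimal, self-contained sketch of block-wise fake quantization in that spirit (this is an illustrative toy, not vLLM's actual `quant_dequant_mxfp4`/`dequant_mxfp4` kernels; the rounding rule and scale choice here are simplified assumptions):

```python
import math

# FP4 (e2m1) representable magnitudes; MXFP4 pairs them with a shared
# power-of-two scale per 32-element block. Illustrative sketch only.
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quant_dequant_mxfp4_sketch(values, block=32):
    """Fake-quantize a flat list block-wise: pick a power-of-two scale so
    the block max fits under 6.0 (the largest fp4 magnitude), snap each
    value to the nearest fp4 grid point, then rescale back."""
    out = []
    for start in range(0, len(values), block):
        chunk = values[start:start + block]
        amax = max(abs(v) for v in chunk) or 1.0
        # mimic the e8m0 shared exponent: a pure power-of-two scale
        scale = 2.0 ** math.ceil(math.log2(amax / 6.0))
        for v in chunk:
            s = v / scale
            # snap the magnitude to the nearest fp4 grid point
            snapped = min(FP4_GRID, key=lambda g: abs(abs(s) - g))
            out.append(math.copysign(snapped, v) * scale)
    return out
```

Applying this to both activations and (at runtime) weights simulates mxfp4 numerics while the real tensors stay in their original storage; note that fake quantization is idempotent, so applying it twice is a no-op.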

Contributor

@BowenBao BowenBao left a comment


LGTM. We should not dequantize the weight during loading; it is handled at inference time by dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype).

@gshtras
Collaborator

gshtras commented Oct 2, 2025

cc @maleksan85

@maleksan85
Contributor

Did you check that the fp4 model runs on MI300, where dequantization is needed?

@fxmarty-amd
Contributor Author

@maleksan85 Yes, there are tests for this, e.g.:

def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):

For context: #17888 (comment)

@maleksan85
Contributor

maleksan85 commented Oct 6, 2025

If you don't mind, please correct the implementation to:

root@banff-cyxtera-s73-5:~/workspace/vllm# git diff
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index 94c0698eb..de1839770 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -219,11 +219,8 @@ class QuarkW4A4MXFP4(QuarkScheme):
                       x_quant_scales: torch.Tensor = None) -> torch.Tensor:

         if self.emulate:
-            dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
-
             x = quant_dequant_mxfp4(x)
-
-            return F.linear(x, dq_w, bias)
+            return F.linear(x, layer.weight, bias)
         else:
             return torch.ops.vllm.gemm_with_dynamic_quant(
                 x, layer.weight, layer.weight_scale, x_quant_scales, self.out_dtype)

Offline dequantization seems better for performance than doing it at runtime. cc @BowenBao: is this option supported?

Thanks.
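The memory side of this tradeoff is easy to estimate: runtime dequantization keeps roughly 4.25 bits per weight resident (4-bit fp4 values plus one e8m0 scale byte per 32-element block), while an offline-dequantized bf16 copy needs 16 bits per weight. A back-of-the-envelope sketch (hypothetical helper and model size, ignoring non-quantized layers):

```python
def weight_bytes(n_params, bits_per_weight, scale_bytes_per_block=0, block=32):
    """Rough weight-memory estimate: packed values plus per-block scales."""
    scales = scale_bytes_per_block * (n_params // block)
    return n_params * bits_per_weight / 8 + scales

n = 70e9  # hypothetical 70B-parameter model
mxfp4 = weight_bytes(n, 4, scale_bytes_per_block=1)  # fp4 + e8m0 scale per 32
bf16 = weight_bytes(n, 16)                           # offline-dequantized copy
# mxfp4 ≈ 37.2e9 bytes vs bf16 = 140e9 bytes: offline dequant costs
# about 16/4.25 ≈ 3.76x the resident weight memory.
```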

@mergify

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fxmarty-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 7, 2025
@BowenBao
Contributor

BowenBao commented Oct 7, 2025

Hi @maleksan85, ideally we'd prefer the current option, as it keeps memory consumption low. This helps with experimenting with large models on MI300X.

@maleksan85
Contributor

maleksan85 commented Oct 7, 2025

> Hi @maleksan85 ideally we'd prefer the current option as it keeps low memory consumption. This help with experimenting large models on mi300x.

It is probably worth keeping a flag for what is preferable, i.e. either speed or memory, with the default set to memory.

@fxmarty-amd
Contributor Author

This fix was merged as part of #21166, see https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py

The flag indeed existed, but was removed per Michael's request (#17888 (comment)). I agree it may be useful to have it, but at the same time the proliferation of ROCm environment variables is something we wish to limit (#21138).

@fxmarty-amd fxmarty-amd closed this Oct 8, 2025
