[MXFP4] Support for linear layers + compressed-tensors integration#41664
Conversation
There was a problem hiding this comment.
Code Review
This pull request renames the MXFP4 quantization scheme to CompressedTensorsW4A4Mxfp4 and introduces support for true W4A4 quantization using FlashInfer on SM100+ devices. The changes include new utility functions for FlashInfer FP4 operations and logic to handle activation quantization. A critical issue was identified in the weight scale processing where swizzle_mxfp4_scales might cause a RuntimeError during reshaping if the output feature size is not a multiple of 128 due to internal padding.
| N, scale_K = layer.weight_scale.shape | ||
| K = scale_K * self.group_size | ||
| layer.weight_scale = Parameter( | ||
| swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), |
There was a problem hiding this comment.
The swizzle_mxfp4_scales function pads the N dimension to the nearest multiple of 128. If N (the output feature size) is not a multiple of 128, the total number of elements in the swizzled tensor will be padded_N * padded_scale_cols, which is not necessarily divisible by N. This will cause a RuntimeError during the .reshape(N, -1) call. Even if it were divisible, the resulting 2D tensor would have misaligned scale data because of the padding introduced during swizzling. You should ensure that the scale tensor's shape is compatible with what the FlashInfer kernel expects, which likely involves keeping the padded dimensions or ensuring the kernel handles the original N correctly with the swizzled layout.
Signed-off-by: Dipika <dipikasikka1@gmail.com>
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
SUMMARY: - Move out of experimental as supported in vLLM as of: vllm-project/vllm#41664 --------- Signed-off-by: Dipika Sikka <ds3822@columbia.edu> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs
left a comment
There was a problem hiding this comment.
Support looks good, was able to verify locally
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
|
Hi @dsikka, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
yewentao256
left a comment
There was a problem hiding this comment.
Sorry to block for a while, please take a look at my previous comment
@yewentao256 Please take a look at the latest commits. This has been addressed to use the padded_N for the reshape |
Dismiss request change as already solved
yewentao256
left a comment
There was a problem hiding this comment.
Thanks for the work! A small update
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Dipika Sikka <ds3822@columbia.edu>
Purpose
flashinfer_mm_fp4andflashinfer_scaled_fp4_mmto take in a configurableblock_sizeand boolean flaguse_nvfp4to enable the mxfp4 linear forward pass based on https://github.com/flashinfer-ai/flashinfer/blob/393e83ea8497ff9fb9ad61e170b89797a6b682a3/flashinfer/gemm/gemm_base.py#L5511Test Plan
LM-Eval
flashinfer
marlin (
VLLM_MXFP4_USE_MARLIN=1)dense (meta-llama/Meta-Llama-3-8B-Instruct)