[ModelOpt] Remove NVFP4 MoE K%16==0 constraint #26891
Conversation
Signed-off-by: XiaobingSuper <[email protected]>
Code Review
This pull request enables tensor parallelism for nv-fp4 MoE layers by removing several assertions in ModelOptNvFp4FusedMoE.process_weights_after_loading. The removed assertions were either redundant or overly strict, preventing the use of tensor parallelism (TP > 1) on certain models.
Specifically:
- The dtype assertions for `w13_weight_scale` and `w2_weight_scale` were redundant, as the `swizzle_blockscale` function already performs this check.
- The shape assertions, which required the last dimension of the weight scales to be divisible by 16, were too strict. With tensor parallelism, this dimension can be sharded and may not remain divisible by 16. The `swizzle_blockscale` function correctly handles this by padding the dimension to a multiple of 4, which is what the underlying kernel requires.
The changes are correct and necessary to support tensor parallelism for these quantized MoE layers. The code is now more robust and consistent with the actual kernel requirements.
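The interaction described above can be illustrated with a small arithmetic sketch. This is not vLLM's code; the block size, model dimension, and helper names below are assumptions chosen for illustration. NVFP4 quantizes along K in blocks of 16, so the block-scale tensor's last dim is K / 16, and sharding K across tensor-parallel ranks can leave that dim non-divisible by 16 even when K itself still is:

```python
BLOCK = 16  # NVFP4 quantization block size along K (per the review above)

def scale_last_dim(k: int, tp: int) -> int:
    """Last dim of the per-block scale tensor on one TP rank."""
    assert k % (BLOCK * tp) == 0, "weight K must shard evenly into blocks"
    return (k // tp) // BLOCK

def pad_to_multiple(n: int, m: int) -> int:
    """What swizzle_blockscale-style padding effectively requires."""
    return ((n + m - 1) // m) * m

k = 768                    # hypothetical MoE intermediate size
s1 = scale_last_dim(k, 1)  # 48 -> divisible by 16, old assert passes
s2 = scale_last_dim(k, 2)  # 24 -> not divisible by 16, old assert fails
print(s1 % 16 == 0, s2 % 16 == 0)    # True False
print(pad_to_multiple(s2, 4) == s2)  # True: already a multiple of 4
```

So the removed assert rejected valid sharded shapes, while padding to a multiple of 4 (the kernel's actual requirement, per the review) accepts them.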
mgoin left a comment
It looks like these asserts aren't present in the compressed-tensors backend and I was able to run vllm serve RedHatAI/Qwen3-30B-A3B-NVFP4 -tp 2 fine, so I think this was just an oversight for modelopt
vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
Lines 341 to 348 in 136a17f
```python
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
    swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
    swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)
```
pavanimajety left a comment
Agree, it's the weight that should be divisible by 16, not the weight scale. Why are we removing the dtype assertion?
pavanimajety left a comment
It seems like the dtype check is already happening in `swizzle_blockscale`.
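A minimal sketch of that pattern, with the helper name adapted and the dtype represented as a string for illustration (not vLLM's actual implementation): putting the dtype guard inside the swizzle helper means every caller gets the check, so a caller-side assert is redundant.

```python
# Assumed: NVFP4 block scales are FP8 E4M3; the exact dtype handling in
# vLLM's swizzle_blockscale may differ.
ALLOWED_SCALE_DTYPE = "float8_e4m3fn"

def swizzle_blockscale_sketch(scale_dtype: str, last_dim: int) -> int:
    # dtype check happens here, once, for every caller
    assert scale_dtype == ALLOWED_SCALE_DTYPE, "block scales must be fp8 e4m3"
    # pad the last dim to a multiple of 4, the kernel's real requirement
    return ((last_dim + 3) // 4) * 4

print(swizzle_blockscale_sketch("float8_e4m3fn", 24))  # 24
print(swizzle_blockscale_sketch("float8_e4m3fn", 30))  # 32
```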
This PR makes the nvfp4 MoE kernel work for TP > 1. We don't need to check that the last dim (K) of the quantized weight scale is divisible by 16; the original weight already guarantees the kernel's requirement.