[ROCm][Quantization] Enable moe_wna16 on ROCm via Triton fallback #35596
brucechanglongxu wants to merge 1 commit into vllm-project:main from
Conversation
Contributor
Code Review
This pull request enables `moe_wna16` quantization on ROCm by ensuring the Triton fallback kernel is used instead of the CUDA-only one, and by adding `moe_wna16` to the list of supported quantization methods for the ROCm platform. The changes are correct and address the issue. I have one suggestion to make the platform check more direct and robust.
Comment on lines 1219 to 1220

```python
current_platform.is_cuda()
and not current_platform.is_rocm()
```
To make the platform check more direct and robust against potential inconsistencies in `is_cuda()` behavior across environments, consider using `current_platform.device_name == "cuda"`. This directly checks for the CUDA platform and is less prone to misinterpretation.
Suggested change

```diff
-current_platform.is_cuda()
-and not current_platform.is_rocm()
+current_platform.device_name == "cuda"
```
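To make the trade-off concrete, here is a minimal standalone sketch of the two check styles being discussed. The `Platform` stub below is hypothetical (vLLM's real `current_platform` object carries much more state), but it reproduces the scenario the PR describes, where `is_cuda()` can report `True` on ROCm as well:

```python
from dataclasses import dataclass


@dataclass
class Platform:
    """Hypothetical stand-in for vLLM's current_platform object."""
    device_name: str  # e.g. "cuda" or "rocm"

    def is_cuda(self) -> bool:
        # Models the behavior described in the PR: under the current
        # platform model this can return True on ROCm too.
        return self.device_name in ("cuda", "rocm")

    def is_rocm(self) -> bool:
        return self.device_name == "rocm"


def use_cuda_kernel_indirect(p: Platform) -> bool:
    # The PR's approach: keep is_cuda() but exclude ROCm explicitly.
    return p.is_cuda() and not p.is_rocm()


def use_cuda_kernel_direct(p: Platform) -> bool:
    # The reviewer's suggestion: compare the device name directly.
    return p.device_name == "cuda"


# Both formulations agree for the platforms in question; the direct
# check simply does not depend on how is_cuda() is defined.
for name in ("cuda", "rocm"):
    p = Platform(name)
    assert use_cuda_kernel_indirect(p) == use_cuda_kernel_direct(p)
```

Either form keeps the CUDA-only op off ROCm; the suggested one-liner just removes the dependency on `is_cuda()` semantics.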
… path

Enable WNA16 (W4A16/W8A16) MoE quantization on ROCm by:

- Adding `"moe_wna16"` to `RocmPlatform.supported_quantization`
- Excluding ROCm from `should_moe_wna16_use_cuda()` so the Triton fallback kernel (`invoke_fused_moe_wna16_triton_kernel`) is used instead of the CUDA-only `moe_wna16_gemm` op

The Triton WNA16 MoE kernel already works on ROCm. Linear layers within `moe_wna16` models fall through to non-Marlin AWQ/GPTQ paths since `check_marlin_supports_layer` returns `False` on ROCm.

This enables popular 4-bit quantized MoE models (Mixtral, DeepSeek, etc.) with GPTQ/AWQ quantization on AMD GPUs.

Signed-off-by: Bruce Changlong Xu <brucechanglongxu@gmail.com>
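The first change (adding `"moe_wna16"` to the supported list) is a gate in platform verification. A rough illustrative sketch, where `verify_quantization` and the list entries other than `"moe_wna16"` are hypothetical stand-ins for vLLM's real `RocmPlatform.supported_quantization` machinery:

```python
# Hypothetical example list; only "moe_wna16" is taken from this PR.
SUPPORTED_QUANTIZATION = ["awq", "gptq", "moe_wna16"]


def verify_quantization(method: str) -> None:
    # Mirrors the rejection behavior: unsupported methods fail early,
    # before any kernel is ever selected.
    if method not in SUPPORTED_QUANTIZATION:
        raise ValueError(f"{method} quantization is not supported on ROCm")


verify_quantization("moe_wna16")  # no longer raises with the PR applied
```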
Force-pushed from 2f63add to 080d293
`moe_wna16` (W4A16/W8A16 MoE quantization, used by GPTQ/AWQ-quantized Mixtral, DeepSeek, etc.) is blocked on ROCm by two issues:

1. Not in `RocmPlatform.supported_quantization`: platform verification rejects it outright.
2. Even if you bypass that, `should_moe_wna16_use_cuda()` in `fused_moe.py` returns `True` on ROCm because it checks `current_platform.is_cuda()`, which returns `True` for ROCm under the current platform model. This routes into `invoke_fused_moe_wna16_cuda_kernel` → `ops.moe_wna16_gemm`, a CUDA-only C++ op that isn't registered on ROCm builds. The Triton fallback path (`invoke_fused_moe_wna16_triton_kernel`) would work fine but never gets reached.

The fix is two lines:

- Add `"moe_wna16"` to `supported_quantization` in `vllm/platforms/rocm.py`
- Add `and not current_platform.is_rocm()` to `should_moe_wna16_use_cuda()` in `vllm/model_executor/layers/fused_moe/fused_moe.py` so the Triton kernel is used instead

The linear layers within `moe_wna16` models already handle ROCm correctly: `check_marlin_supports_layer()` in `marlin_utils.py` returns `False` on ROCm (lines 213-214), so `MoeWNA16Config.get_quant_method()` falls through to the non-Marlin AWQ/GPTQ paths, which have working ROCm support via Exllama/Conch.

The Triton WNA16 MoE kernel is the same one used on CUDA when `should_moe_wna16_use_cuda()` returns false (the W8A16 case, or large batch sizes), so it's well-exercised in existing CI.
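The dispatch described above can be sketched as a standalone Python snippet. The function and kernel names follow the PR, but the extra parameters and the batch-size threshold are illustrative assumptions, not vLLM's actual signature or tuning values:

```python
def should_moe_wna16_use_cuda(is_cuda: bool, is_rocm: bool,
                              num_bits: int, batch_size: int) -> bool:
    # The CUDA-only moe_wna16_gemm op only serves W4A16 at small batch
    # sizes on CUDA proper; "and not is_rocm" is the guard this PR adds.
    # The num_bits/batch_size conditions and the threshold of 32 are
    # illustrative placeholders.
    return is_cuda and not is_rocm and num_bits == 4 and batch_size <= 32


def pick_kernel(is_cuda: bool, is_rocm: bool,
                num_bits: int, batch_size: int) -> str:
    if should_moe_wna16_use_cuda(is_cuda, is_rocm, num_bits, batch_size):
        return "invoke_fused_moe_wna16_cuda_kernel"
    # Triton fallback: W8A16, large batches, and (after this PR) all of ROCm.
    return "invoke_fused_moe_wna16_triton_kernel"


# On ROCm (where is_cuda can also be True per the PR's description),
# the Triton fallback is now always chosen:
assert pick_kernel(True, True, 4, 8) == "invoke_fused_moe_wna16_triton_kernel"
```

The same `pick_kernel` shape also shows why the change is low-risk: the Triton branch it routes ROCm into is the branch CUDA already takes for W8A16 and large batches.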