[gpt-oss] triton kernel mxfp4 #22421
Conversation
Signed-off-by: <[email protected]>
Code Review
This pull request introduces support for mxfp4 quantization on Hopper GPUs by integrating a new Triton kernel for MoE layers. The changes include adding the kernel wrappers, modifying the mxfp4 quantization path to use it, and adding corresponding tests. The implementation looks solid, but I have two high-level concerns. First, the number of warps for the Triton kernel is configured statically based on an environment variable, which might not be optimal or correct for dynamic batch sizes at runtime. Second, a utility function modifies a global configuration flag, which is a risky pattern that could lead to hard-to-debug side effects. Addressing these points would improve the robustness and maintainability of this new feature.
```python
# FIXME warp need to be adjusted based on batch size
# only apply to batched mode
if self.moe.use_ep:
    num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
    num_warps = 8
```
The FIXME comment on line 301 indicates that num_warps should be adjusted based on the batch size. The current implementation determines num_warps based on the static environment variable VLLM_MOE_DP_CHUNK_SIZE, which may not reflect the dynamic batch size at runtime. This static configuration could lead to suboptimal performance or potential correctness issues if the Triton kernel has strict requirements for num_warps based on the input size. This value is used during weight loading to swizzle the weights, so it cannot be changed dynamically per batch without re-swizzling. This suggests a potential design issue that should be addressed for robust performance and correctness.
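For illustration only, a runtime-keyed variant could look like the sketch below. This is a hypothetical helper, not code from the PR; as noted above, the PR cannot adopt it as-is because the warp count is baked into the weight swizzle at load time.

```python
def pick_num_warps(num_tokens: int, threshold: int = 512) -> int:
    """Hypothetical heuristic mirroring the PR's static rule, but keyed on
    the actual token count seen at runtime rather than on the
    VLLM_MOE_DP_CHUNK_SIZE environment variable: fewer warps for small
    batches, more for large ones."""
    return 4 if num_tokens <= threshold else 8
```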
```python
if current_platform.is_cuda() and \
        current_platform.is_device_capability(100):
    constraints = {
        "is_persistent": True,
        "epilogue_subtile": 1,
    }
    opt_flags.update_opt_flags_constraints(constraints)
```
The function _swizzle_mxfp4 modifies a global state by calling opt_flags.update_opt_flags_constraints(constraints). Modifying global state within a utility function is a dangerous pattern as it can introduce non-local side effects that are difficult to debug, especially in a system that might handle multiple models or requests concurrently. This could cause issues if different models or layers have conflicting requirements for these optimization flags. It would be safer to manage this global state with more care, for example, by using a context manager to set and restore the flags, or by passing constraints as parameters to the underlying kernel if the API supports it.
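A minimal sketch of the context-manager idea, assuming that re-applying an empty constraints dict restores the default behaviour (this assumption would need to be verified against the `triton_kernels` opt_flags API):

```python
import contextlib


@contextlib.contextmanager
def scoped_opt_flags(opt_flags, constraints: dict):
    """Apply opt-flag constraints only for the duration of the block.

    Assumption: update_opt_flags_constraints({}) resets the module-level
    constraints. If the real API exposes the previously active constraints,
    saving and restoring that copy would be safer.
    """
    opt_flags.update_opt_flags_constraints(constraints)
    try:
        yield
    finally:
        opt_flags.update_opt_flags_constraints({})
```

With something like this, the Hopper-specific constraints would only apply around the swizzling call instead of leaking into unrelated kernel launches.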
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
```python
def has_triton_kernels() -> bool:
    """Whether the optional `triton_kernels` package is available."""
    return _has_module("triton_kernels")
```
QQ: How can I install this?
We need to install it directly from the Triton repo; there's no PyPI wheel yet:

```bash
uv pip install triton/python/triton_kernels --no-deps
```
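After installing, a quick sanity check along the lines of the `has_triton_kernels` helper above can confirm the package is importable. This is a standalone sketch using `importlib`, not the exact vLLM code:

```python
import importlib.util


def triton_kernels_available() -> bool:
    """Return True if the optional triton_kernels package can be imported."""
    return importlib.util.find_spec("triton_kernels") is not None


if __name__ == "__main__":
    print("triton_kernels available:", triton_kernels_available())
```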
Hmm, this broke the trunk.
Are you running […]? You need to install `triton_kernels`:

```bash
git clone https://github.com/triton-lang/triton
uv pip install triton/python/triton_kernels --no-deps
```
Pushed a fix: #22529
Just FYI, the error also shows up on the llama4 benchmark run (https://github.com/pytorch/pytorch-integration-testing/actions/runs/16834994069/job/47692144587#step:14:3962), so it affects other models too.
Yeah, I was running deepseek; the code path is shared. Thanks for the quick fix!
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: Noam Gat <[email protected]>
Hi @zyongye, can we use this kernel on Blackwell? If so, could you provide the Triton commit? I encountered the following issue when running the UT locally.
Hi @yiliu30, for Blackwell (SM100) we have FlashInfer kernels available; please see the recipe for details: https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html#b200
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: Paul Pak <[email protected]>
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: Diego-Castan <[email protected]>
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: Xiao Yu <[email protected]>
Signed-off-by: <[email protected]> Signed-off-by: Yongye Zhu <[email protected]>
```python
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
                              value_layout, **value_layout_opts)
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
                       **scale_layout_opts)
return quant_tensor, InFlexData(), scale
```
Is it safe to unwrap the `triton_kernels.tensor.Tensor` here? Could we avoid it in the first place?
This is a util function from `triton_kernels`; it is designed to take a `triton_kernels.Tensor` instead of a `torch.Tensor`.
```python
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
```
link

Needs nightly torch and Triton main to work.

Don't merge yet; waiting for the accuracy test.