
[MoE Refactoring][Bugfix] Wrap WNA16 Triton kernel into mk and change compressed tensor kernel selection #31752

Merged
robertgshaw2-redhat merged 6 commits into vllm-project:main from zyongye:wna16_triton_fix
Jan 9, 2026
Conversation

@zyongye
Member

@zyongye zyongye commented Jan 5, 2026

Refactor WNA16 Triton kernel into modular kernel format.
Previously, we factored WNA16 out of TritonExperts, which prevented the WNA16 LoRA module from selecting the correct kernel (see issue). This PR fixes the problem.
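A minimal sketch of the selection fix described above: the WNA16 path resolves to a dedicated modular-kernel class instead of the generic TritonExperts, so downstream consumers (such as the LoRA module) can match on the concrete type. The class bodies and the helper function here are illustrative stand-ins, not vLLM's actual API.

```python
# Hypothetical sketch of the kernel-selection change; names are
# illustrative, not vLLM's real selection logic.

class TritonExperts:
    """Stand-in for the generic fused-MoE Triton experts class."""


class TritonWNA16Experts(TritonExperts):
    """Stand-in for the new WNA16-specific modular kernel class."""


def select_experts_impl(use_int4_w4a16: bool, use_int8_w8a16: bool) -> type:
    # WNA16 (int4/int8 weights with 16-bit activations) now gets its
    # own class rather than falling back to the generic TritonExperts.
    if use_int4_w4a16 or use_int8_w8a16:
        return TritonWNA16Experts
    return TritonExperts
```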

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the WNA16 Triton kernel into a modular kernel class TritonWNA16Experts and fixes a bug in the compressed tensor kernel selection to use this new class. The changes are logical, but there is significant code duplication in the new TritonWNA16Experts class, which inherits from TritonExperts. I've provided suggestions to refactor this to improve maintainability and adhere to the DRY principle. Specifically, the apply method is almost a complete copy of its parent's implementation, and the moe_sum method is also duplicated.

Comment on lines +2443 to +2586

        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )

        if global_num_experts == -1:
            global_num_experts = E

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            self.quant_config.config_name(hidden_states.dtype),
            num_tokens,
            block_shape=self.block_shape,
        )

        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif (
            hidden_states.dtype == torch.float8_e4m3fn
            or hidden_states.dtype == torch.float8_e4m3fnuz
        ):
            compute_type = tl.bfloat16
        else:
            raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

        # Note that the output tensor might be in workspace1
        intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
        intermediate_cache2 = _resize_cache(
            workspace13, (num_tokens * top_k_num, N // 2)
        )
        intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        invoke_fused_moe_wna16_triton_kernel(
            hidden_states,
            w1,
            intermediate_cache1,
            self.w1_scale,
            self.quant_config.w1_zp,
            None,  # topk_weights
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,  # mul_routed_weights
            top_k_num,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        self.activation(
            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        a2q_scale: torch.Tensor | None = None

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2,
            a2_scale,
            self.quant_dtype,
            self.per_act_token_quant,
            self.block_shape,
        )

        invoke_fused_moe_wna16_triton_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            self.w2_scale,
            self.quant_config.w2_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        # separate function is required for MoE + LoRA
        self.moe_sum(intermediate_cache3, output)

    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)

Contributor


critical

The TritonWNA16Experts class, particularly the apply method, contains a large amount of duplicated code from its parent class TritonExperts. This significantly impacts maintainability, as any changes to the base logic would need to be manually synchronized in this class as well.

To adhere to the Don't Repeat Yourself (DRY) principle, I recommend refactoring this. A possible approach is to extract the common logic into a protected base method in TritonExperts and have TritonWNA16Experts override only the parts that differ, such as the kernel invocation logic.

For example, you could introduce helper methods in TritonExperts that TritonWNA16Experts can override:

class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    # ...

    def _get_kernel_fn(self):
        return invoke_fused_moe_triton_kernel

    def _get_first_gemm_kwargs(self, ...):
        # returns kwargs for the first GEMM
        return { ... }

    def _get_second_gemm_kwargs(self, ...):
        # returns kwargs for the second GEMM
        return { ... }

    def apply(self, ...):
        # ... common setup ...

        kernel_fn = self._get_kernel_fn()
        first_gemm_kwargs = self._get_first_gemm_kwargs(...)
        kernel_fn(**first_gemm_kwargs)

        # ... common activation and quantization ...

        second_gemm_kwargs = self._get_second_gemm_kwargs(...)
        kernel_fn(**second_gemm_kwargs)

        # ... common finalization ...

class TritonWNA16Experts(TritonExperts):
    def _get_kernel_fn(self):
        return invoke_fused_moe_wna16_triton_kernel

    def _get_first_gemm_kwargs(self, ...):
        # returns kwargs for the first WNA16 GEMM
        return { ... }

    def _get_second_gemm_kwargs(self, ...):
        # returns kwargs for the second WNA16 GEMM
        return { ... }

    # No need to override apply() if all differences are in helper methods.

This refactoring would make the code much cleaner, more maintainable, and less prone to bugs from unsynchronized changes.
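The suggestion above is the classic template-method pattern; a minimal runnable illustration of it (with toy integer "kernels" standing in for the Triton GEMM invocations, and class names mirroring but not matching vLLM's) looks like this:

```python
# Illustrative template-method sketch: the base class owns the apply()
# control flow; the subclass overrides only the kernel hook.

class TritonExpertsSketch:
    def _kernel(self, x: int) -> int:
        # stand-in for invoke_fused_moe_triton_kernel
        return x * 2

    def apply(self, x: int) -> int:
        # common setup ... first GEMM ...
        y = self._kernel(x)
        # ... common activation/quantization ... second GEMM ...
        return self._kernel(y)


class TritonWNA16ExpertsSketch(TritonExpertsSketch):
    def _kernel(self, x: int) -> int:
        # stand-in for invoke_fused_moe_wna16_triton_kernel
        return x * 3
```

Only the hook differs between the two classes, so a change to the shared control flow in `apply` propagates to both automatically.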

Comment on lines +2584 to +2585
    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)
Contributor


high

This moe_sum method is identical to the one in the parent class TritonExperts. It can be removed to avoid code duplication and improve maintainability.
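A generic illustration of why the override is redundant (toy classes, not vLLM's): an identical method re-defined in a subclass adds no behavior over plain inheritance.

```python
# Toy example: Child inherits moe_sum from Base; re-defining an
# identical moe_sum in Child would only duplicate code.

class Base:
    def moe_sum(self, xs):
        return sum(xs)


class Child(Base):
    pass  # inherits Base.moe_sum unchanged
```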

@mergify

mergify bot commented Jan 5, 2026

Hi @zyongye, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

1 similar comment

@JartX
Contributor

JartX commented Jan 5, 2026

@zyongye works, problem solved for me :)

Collaborator

@bnellnm bnellnm left a comment


Looks ok to me. Can you update tests/kernels/moe/modular_kernel_tools/mk_objects.py so that the new experts class gets tested?

@zyongye zyongye requested a review from WoosukKwon as a code owner January 6, 2026 18:29
@zyongye zyongye added ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm labels Jan 6, 2026
@zyongye zyongye moved this to In progress in MoE Refactor Jan 6, 2026
register_experts(
    TritonWNA16Experts,
    standard_format,
    common_float_and_int_types,
Collaborator


Doesn't this kernel only work for 16-bit float activation types?
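If the reviewer's point holds, the registration would narrow the activation dtypes to 16-bit floats. A hypothetical sketch, with dtype names as strings for illustration (vLLM uses torch dtypes, and `common_float_and_int_types` here is an illustrative list, not the project's actual constant):

```python
# Hypothetical: restrict the WNA16 registration to 16-bit float
# activation types instead of the full common_float_and_int_types list.
common_float_and_int_types = ["float32", "float16", "bfloat16", "int8"]

wna16_activation_types = [
    t for t in common_float_and_int_types
    if t in ("float16", "bfloat16")
]
```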

Robert Shaw and others added 6 commits January 7, 2026 19:46
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This reverts commit fa893c1932d01cec5449f52ec55413f07932549e.

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye
Member Author

zyongye commented Jan 7, 2026

After discussion with @bnellnm, we decided to defer adding tests to the other PR since the current test structure lacks int4 test infra.

@robertgshaw2-redhat robertgshaw2-redhat moved this from In progress to In review in MoE Refactor Jan 8, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit d62cfe5 into vllm-project:main Jan 9, 2026
56 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in MoE Refactor Jan 9, 2026
@zyongye zyongye deleted the wna16_triton_fix branch January 9, 2026 00:02
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
…compressed tensor kernel selection (vllm-project#31752)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…compressed tensor kernel selection (vllm-project#31752)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…compressed tensor kernel selection (vllm-project#31752)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…compressed tensor kernel selection (vllm-project#31752)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants