[MoE Refactoring][Bugfix] Wrap WNA16 Triton kernel into mk and change compressed tensor kernel selection #31752
Conversation
Code Review
This pull request refactors the WNA16 Triton kernel into a modular kernel class TritonWNA16Experts and fixes a bug in the compressed tensor kernel selection to use this new class. The changes are logical, but there is significant code duplication in the new TritonWNA16Experts class, which inherits from TritonExperts. I've provided suggestions to refactor this to improve maintainability and adhere to the DRY principle. Specifically, the apply method is almost a complete copy of its parent's implementation, and the moe_sum method is also duplicated.
```python
        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )

        if global_num_experts == -1:
            global_num_experts = E

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            self.quant_config.config_name(hidden_states.dtype),
            num_tokens,
            block_shape=self.block_shape,
        )

        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif (
            hidden_states.dtype == torch.float8_e4m3fn
            or hidden_states.dtype == torch.float8_e4m3fnuz
        ):
            compute_type = tl.bfloat16
        else:
            raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

        # Note that the output tensor might be in workspace1
        intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
        intermediate_cache2 = _resize_cache(
            workspace13, (num_tokens * top_k_num, N // 2)
        )
        intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        invoke_fused_moe_wna16_triton_kernel(
            hidden_states,
            w1,
            intermediate_cache1,
            self.w1_scale,
            self.quant_config.w1_zp,
            None,  # topk_weights
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,  # mul_routed_weights
            top_k_num,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        self.activation(
            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        a2q_scale: torch.Tensor | None = None

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2,
            a2_scale,
            self.quant_dtype,
            self.per_act_token_quant,
            self.block_shape,
        )

        invoke_fused_moe_wna16_triton_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            self.w2_scale,
            self.quant_config.w2_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        # separate function is required for MoE + LoRA
        self.moe_sum(intermediate_cache3, output)

    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)
```
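The dtype-to-compute-type branching in the hunk above reduces to a small lookup table. The sketch below is illustrative only: it uses dtype name strings in place of the actual `torch`/`triton` dtype objects, and `select_compute_type` is a hypothetical helper, not part of vLLM. Note that the two fp8 activation dtypes fall back to bfloat16 for accumulation.

```python
# Illustrative stand-in for the compute-type selection above.
# Keys are activation dtypes; values are the Triton compute dtypes.
# fp8 activations have no matching Triton compute type here, so the
# kernel accumulates in bfloat16 instead.
_COMPUTE_TYPE = {
    "bfloat16": "bfloat16",
    "float16": "float16",
    "float32": "float32",
    "float8_e4m3fn": "bfloat16",
    "float8_e4m3fnuz": "bfloat16",
}

def select_compute_type(dtype_name: str) -> str:
    try:
        return _COMPUTE_TYPE[dtype_name]
    except KeyError:
        # mirrors the ValueError raised in the diff above
        raise ValueError(f"Unsupported compute_type: {dtype_name}") from None
```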
The TritonWNA16Experts class, particularly the apply method, contains a large amount of duplicated code from its parent class TritonExperts. This significantly impacts maintainability, as any changes to the base logic would need to be manually synchronized in this class as well.
To adhere to the Don't Repeat Yourself (DRY) principle, I recommend refactoring this. A possible approach is to extract the common logic into a protected base method in TritonExperts and have TritonWNA16Experts override only the parts that differ, such as the kernel invocation logic.
For example, you could introduce helper methods in TritonExperts that TritonWNA16Experts can override:

```python
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    # ...
    def _get_kernel_fn(self):
        return invoke_fused_moe_triton_kernel

    def _get_first_gemm_kwargs(self, ...):
        # returns kwargs for the first GEMM
        return { ... }

    def _get_second_gemm_kwargs(self, ...):
        # returns kwargs for the second GEMM
        return { ... }

    def apply(self, ...):
        # ... common setup ...
        kernel_fn = self._get_kernel_fn()
        first_gemm_kwargs = self._get_first_gemm_kwargs(...)
        kernel_fn(**first_gemm_kwargs)
        # ... common activation and quantization ...
        second_gemm_kwargs = self._get_second_gemm_kwargs(...)
        kernel_fn(**second_gemm_kwargs)
        # ... common finalization ...


class TritonWNA16Experts(TritonExperts):
    def _get_kernel_fn(self):
        return invoke_fused_moe_wna16_triton_kernel

    def _get_first_gemm_kwargs(self, ...):
        # returns kwargs for the first WNA16 GEMM
        return { ... }

    def _get_second_gemm_kwargs(self, ...):
        # returns kwargs for the second WNA16 GEMM
        return { ... }

    # No need to override apply() if all differences are in helper methods.
```

This refactoring would make the code much cleaner, more maintainable, and less prone to bugs from unsynchronized changes.
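The template-method shape of that suggestion can be demonstrated with a runnable toy. All names and arithmetic below are hypothetical stand-ins (vLLM's real classes carry far more state and call Triton kernels); the point is only that the subclass swaps the kernel while inheriting `apply()` unchanged.

```python
class BaseExperts:
    """Toy base class: apply() owns the two-GEMM pipeline shape."""

    def _kernel(self, x: list[int], scale: int) -> list[int]:
        # stand-in for the standard fused-MoE GEMM
        return [v * scale for v in x]

    def _first_gemm_kwargs(self, x: list[int]) -> dict:
        return {"x": x, "scale": 2}

    def _second_gemm_kwargs(self, x: list[int]) -> dict:
        return {"x": x, "scale": 3}

    def apply(self, x: list[int]) -> list[int]:
        # common setup / activation / finalization live only here
        h = self._kernel(**self._first_gemm_kwargs(x))
        return self._kernel(**self._second_gemm_kwargs(h))


class WNA16Experts(BaseExperts):
    """Only the kernel differs; apply() is inherited from the base."""

    def _kernel(self, x: list[int], scale: int) -> list[int]:
        # WNA16 GEMM stand-in: pretend dequantization adds a bias of 1
        return [v * scale + 1 for v in x]
```

With this shape, `WNA16Experts().apply([1, 2])` runs the shared pipeline through the WNA16 kernel stand-in without duplicating any of `apply()`.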
```python
    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)
```
Hi @zyongye, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
1 similar comment
@zyongye works, problem solved for me :)
bnellnm left a comment
Looks ok to me. Can you update tests/kernels/moe/modular_kernel_tools/mk_objects.py so that the new experts class gets tested?
```python
register_experts(
    TritonWNA16Experts,
    standard_format,
    common_float_and_int_types,
```
Doesn't this kernel only work for 16-bit float activation types?
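The reviewer's concern amounts to narrowing the registered activation dtypes to 16-bit floats. As a hedged sketch (the dtype list and helper below are hypothetical; the real `mk_objects.py` registry uses torch dtype objects and may expose a different API):

```python
# Hypothetical dtype-name list standing in for common_float_and_int_types.
common_float_and_int_types = [
    "float16", "bfloat16", "float32", "int8", "int4",
]

def sixteen_bit_float_types(types: list[str]) -> list[str]:
    # WNA16 = weight-only quantization with 16-bit float activations,
    # so only fp16/bf16 activations would be registered for this kernel.
    return [t for t in types if t in ("float16", "bfloat16")]
```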
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This reverts commit fa893c1932d01cec5449f52ec55413f07932549e. Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
After discussion with @bnellnm, we decided to defer adding tests to the other PR since the current test structure lacks int4 test infra.
…compressed tensor kernel selection (vllm-project#31752) Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
…compressed tensor kernel selection (vllm-project#31752) Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Refactor WNA16 Triton kernel into modular kernel format.
Previously, we factored WNA16 out of
TritonExperts, which meant the WNA16 LoRA module could not select the correct kernel (see the linked issue). This PR fixes that problem.
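The selection fix can be sketched as a simple dispatch on the quantization flags that already appear in the diff (`use_int4_w4a16`, `use_int8_w8a16`). The helper below is hypothetical and returns class names as strings; the real compressed-tensors selection logic lives elsewhere in vLLM and constructs the experts object directly.

```python
def select_experts_cls(use_int4_w4a16: bool, use_int8_w8a16: bool) -> str:
    # WNA16 = weight-only int4/int8 quantization with 16-bit activations;
    # either flag routes to the new modular WNA16 experts class.
    if use_int4_w4a16 or use_int8_w8a16:
        return "TritonWNA16Experts"
    return "TritonExperts"
```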