Support DeepEP for Kimi-k2-thinking through enabling gemm selection for compressed-tensor marlin wna16 #28574
```diff
@@ -35,7 +35,11 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
```
```diff
@@ -1562,7 +1566,36 @@
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        if self.num_bits != 4:
+            return None
+        return int4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            return BatchedMarlinExperts(
+                max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+            )
+        else:
+            return MarlinExperts(self.moe_quant_config)

     def apply(
         self,
```

Review thread on lines 1578 to 1614:

Contributor: @luccafong - This call out seems reasonable. Looks like you'd need to plumb through

luccafong (Author): hmm, seems we will need to touch the base method signature of

luccafong (Author): let me try the latter approach

luccafong (Author): resolved in 1e18fc8
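For orientation, the sketch below shows how these two hooks are consumed: `get_fused_moe_quant_config` describes the int4 weight-only scheme (per-group scales of width `group_size` along K, no zero points), and `select_gemm_impl` hands back a Marlin experts implementation whose activation layout matches the dispatcher. Only the hook names and the activation-format check come from this diff; the `FusedMoEModularKernel` composition is an assumption based on vLLM's modular-kernel design, not something this PR adds.

```python
# A minimal sketch, assuming vLLM's modular-kernel composition; only the
# two hooks and the activation-format check are taken from the diff above.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk


def build_moe_kernel(method, prepare_finalize, layer):
    # DeepEP-style all-to-all dispatchers deliver a batched (expert-major)
    # activation tensor per rank, so the method must return
    # BatchedMarlinExperts; contiguous dispatchers get plain MarlinExperts.
    experts = method.select_gemm_impl(prepare_finalize, layer)
    # Assumed composition: prepare/finalize handles dispatch and combine,
    # and the experts object runs the Marlin WNA16 grouped GEMMs in between.
    return mk.FusedMoEModularKernel(prepare_finalize, experts)
```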
```diff
@@ -1573,7 +1606,7 @@
         renormalize: bool,
         use_grouped_topk: bool = False,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
         custom_routing_function: Callable | None = None,
```
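The `topk_group` / `num_expert_group` parameters in this signature drive DeepSeek-style grouped expert routing, where each token first picks the best expert groups and then the top-k experts within them. Below is a minimal, self-contained sketch of that scheme; it is an illustration of the algorithm, not vLLM's exact implementation.

```python
import torch


def grouped_topk(scores: torch.Tensor, top_k: int,
                 num_expert_group: int, topk_group: int):
    """Sketch of grouped top-k routing. Assumes num_experts divides
    evenly into num_expert_group groups."""
    num_tokens, num_experts = scores.shape
    # Score each group of experts by its best expert.
    group_scores = scores.view(
        num_tokens, num_expert_group, -1).max(dim=-1).values
    # Keep only the topk_group best groups per token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    # Expand the group mask to per-expert granularity and mask the rest out.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group, num_experts // num_expert_group
    ).reshape(num_tokens, num_experts)
    masked = scores.masked_fill(score_mask == 0, float("-inf"))
    # Standard top-k over the surviving experts; renormalization of the
    # returned weights (when renormalize=True) happens separately.
    return torch.topk(masked, k=top_k, dim=-1)
```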