-
-
Notifications
You must be signed in to change notification settings. Fork 15.7k
[Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE #30014
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bcf3413
193328c
d8127ce
4d15aae
dd5ae90
15f9155
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -318,17 +318,23 @@ def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): | |
|
|
||
| return output_list | ||
|
|
||
| def dispatch( | ||
| def dispatch( # type: ignore[override] | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
| is_sequence_parallel: bool = False, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| extra_tensors: list[torch.Tensor] | None = None, | ||
| ) -> ( | ||
| tuple[torch.Tensor, torch.Tensor] | ||
| | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | ||
| ): | ||
| assert self.all2all_manager is not None | ||
| hidden_states, router_logits = self.all2all_manager.dispatch( | ||
| hidden_states, router_logits, is_sequence_parallel | ||
| return self.all2all_manager.dispatch( | ||
| hidden_states, | ||
| router_logits, | ||
| is_sequence_parallel, | ||
| extra_tensors, # type: ignore[call-arg] | ||
|
Comment on lines
+332
to
+336
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The new pre-quantized path now calls […comment truncated in page capture…] Useful? React with 👍 / 👎. |
||
| ) | ||
| return hidden_states, router_logits | ||
|
|
||
| def combine( | ||
| self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,7 @@ | |
| is_flashinfer_supporting_global_sf, | ||
| ) | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe | ||
| from vllm.utils.math_utils import cdiv, round_up | ||
| from vllm.utils.torch_utils import ( | ||
| aux_stream, | ||
|
|
@@ -1933,10 +1934,46 @@ def forward_impl( | |
| ) | ||
|
|
||
| with sp_ctx: | ||
| extra_tensors = None | ||
| if do_naive_dispatch_combine: | ||
| hidden_states_combined, router_logits = get_ep_group().dispatch( | ||
| hidden_states, router_logits, self.is_sequence_parallel | ||
| # Avoid circular import | ||
| from vllm.model_executor.layers.quantization.modelopt import ( | ||
| ModelOptNvFp4FusedMoE, | ||
| ) | ||
|
|
||
| post_quant_allgather = ( | ||
| has_flashinfer_trtllm_fused_moe() | ||
| and self.quant_method is not None | ||
| and self.dp_size > 1 | ||
| and self.use_ep | ||
| and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) | ||
| ) | ||
| if post_quant_allgather: | ||
| hidden_states_to_dispatch, extra_tensors = ( | ||
| self.quant_method.prepare_dp_allgather_tensor( | ||
| self, hidden_states, router_logits | ||
| ) | ||
| ) | ||
| else: | ||
| hidden_states_to_dispatch = hidden_states | ||
|
|
||
| dispatch_res = get_ep_group().dispatch( | ||
| hidden_states_to_dispatch, | ||
| router_logits, | ||
|
Comment on lines
+1960
to
+1962
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
In the DP dispatch path the call to […comment truncated in page capture…] Useful? React with 👍 / 👎. |
||
| self.is_sequence_parallel, | ||
| extra_tensors=extra_tensors, | ||
| ) | ||
| if extra_tensors is not None: | ||
| hidden_states_combined, router_logits, extra_tensors_combined = ( | ||
| dispatch_res | ||
| ) | ||
| hidden_states_combined = ( | ||
| hidden_states_combined, | ||
| extra_tensors_combined[0], | ||
| ) | ||
| else: | ||
| hidden_states_combined, router_logits = dispatch_res | ||
|
|
||
| # Run shared experts before matrix multiply. | ||
| # because matrix multiply maybe modify the hidden_states. | ||
| if has_separate_shared_experts and not use_shared_experts_stream: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.