[MoE Refactor] Split of DefaultMoERunner class #35326
@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.forward_context import (
    get_forward_context,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
    SharedExperts,
)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.workspace import current_workspace_manager


class ChunkingMoERunner(MoERunnerBase):
    """
    MoE runner wrapper that adds chunked processing to any MoERunnerBase.

    This runner wraps an inner MoERunnerBase and overrides _forward_impl to
    process large batches by breaking them into smaller chunks. Each chunk
    is delegated to the inner runner's _forward_impl, making chunking
    composable with any runner implementation.

    All MoERunnerBase state (moe_config, router, quant_method, etc.) is
    transparently delegated to the inner runner via __getattr__.
    ChunkingMoERunner only owns chunking-specific state: the pre-allocated
    workspace buffers and the reduce_results override.

    Key behaviors:
    - Pre-allocates workspace tensors for CUDA graph compatibility
    - Processes chunks via inner._forward_impl per chunk
    - Never reduces results (reduce_results always returns False)
    """

    def __init__(self, inner: MoERunnerBase):
        # Assert that _maybe_dispatch/_maybe_combine will be nops.
        assert inner.moe_config.pcp_size == 1

        # Skip MoERunnerBase.__init__ — all state is delegated to inner
        # via __getattr__. Only chunking-specific state lives here.
        self._inner = inner

        # Pre-allocated staging buffers. These need to exist ahead of time
        # due to CUDA graph construction needing fixed buffer addresses.
        self.batched_hidden_states, self.batched_router_logits = (
            self._init_dp_chunking()
        )

    def __getattr__(self, name):
        # Delegate attribute access to the inner runner. This is only
        # called when normal lookup (instance __dict__, class MRO) fails,
        # so ChunkingMoERunner's own attributes and methods take priority.
        return getattr(self._inner, name)

    @property
    def shared_experts(self) -> SharedExperts | None:
        return self._inner.shared_experts

    # TODO(bnell): temporary hack, do not call this method.
    def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
        self._inner._replace_quant_method(quant_method)
        self.quant_method = quant_method

    def is_internal_router(self) -> bool:
        return self._inner.gate is not None

    # When DP chunking is enabled, reducing the results is handled by the
    # MK finalize operations.
    # This will be removed by #35949
    @property
    def reduce_results(self) -> bool:
        return False

    def _init_dp_chunking(self) -> list[torch.Tensor]:
        states_shape: tuple[int, ...]
        logits_shape: tuple[int, ...]

        moe = self.moe_config

        if self.enable_dbo:
            states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
        else:
            states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)

        # Does this need some kind of profiling run check like modular_kernel.py?
        return current_workspace_manager().get_simultaneous(

Collaborator: Why is this here? I didn't see this in the previous implementation.

Collaborator: Previously, it was …

Collaborator (author): Yeah, the workspace is now shared among all layers. Previously there were separate buffers for each layer.

            (states_shape, moe.in_dtype),
            (logits_shape, moe.router_logits_dtype),
        )

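To illustrate the reply above about the workspace now being shared across layers, here is a deliberately simplified stand-in, not part of the diff. ToyWorkspaceManager and its get_simultaneous signature are hypothetical and only mimic the shape of the call in _init_dp_chunking, not the real vllm.v1.worker.workspace API:

    # Illustrative only: a toy shared-workspace manager. All names here are
    # made up; the real vLLM workspace manager may differ.
    import torch


    class ToyWorkspaceManager:
        """Hands out cached buffers keyed by (shape, dtype) instead of
        allocating per-layer staging tensors."""

        def __init__(self, device: str = "cpu"):
            self._cache: dict[tuple, torch.Tensor] = {}
            self._device = device

        def get_simultaneous(
            self, *specs: tuple[tuple[int, ...], torch.dtype]
        ) -> list[torch.Tensor]:
            # One persistent tensor per spec, reused by every caller; the
            # buffer addresses therefore stay fixed, which is the property
            # CUDA graph capture needs.
            out = []
            for shape, dtype in specs:
                key = (shape, dtype)
                if key not in self._cache:
                    self._cache[key] = torch.empty(
                        shape, dtype=dtype, device=self._device
                    )
                out.append(self._cache[key])
            return out


    manager = ToyWorkspaceManager()
    hidden, logits = manager.get_simultaneous(
        ((64, 1024), torch.bfloat16),
        ((64, 8), torch.float32),
    )
    assert hidden.shape == (64, 1024) and logits.shape == (64, 8)

Because every layer asks for the same (shape, dtype) specs, they all receive the same cached tensors back, which is the "one shared workspace instead of per-layer buffers" behavior described in the reply above.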
    def _allocate_dp_chunking_outputs(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        # Assert the inputs are of the proper type and shape.
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None

        assert self.batched_hidden_states.dtype == hidden_states.dtype, (
            f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
        )
        assert self.batched_router_logits.dtype == router_logits.dtype, (
            f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
        )

        # Check size compatibility.
        assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == router_logits.size(-1)

        final_fused_hidden_states = torch.empty_like(hidden_states)
        if self.shared_experts is not None:
            if shared_experts_input is not None:
                final_shared_hidden_states = torch.empty_like(shared_experts_input)
            else:
                final_shared_hidden_states = torch.empty_like(hidden_states)
        else:
            final_shared_hidden_states = None

        return final_shared_hidden_states, final_fused_hidden_states

    def _slice_and_copy_input(
        self,
        out_slice: torch.Tensor,
        orig: torch.Tensor | None,
        start: int,
        end: int,
    ) -> torch.Tensor:
        assert orig is not None
        slice_size = end - start
        orig_slice = orig[start:end, :]
        if self.enable_dbo:
            assert out_slice.dim() == 3
            batch_buffer_idx = dbo_current_ubatch_id()
            out_slice = out_slice[batch_buffer_idx, :]

        assert out_slice.size(0) >= slice_size
        out_slice = out_slice[:slice_size, :]
        out_slice.copy_(orig_slice, non_blocking=True)
        return out_slice

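As a side note on the DBO path in _slice_and_copy_input above, here is a minimal, self-contained illustration, not part of the diff and using toy shapes, of how a double-buffered staging tensor with a leading ubatch dimension is indexed and sliced before the copy:

    # Toy sketch of DBO double-buffer indexing; names and sizes are invented.
    import torch

    staging = torch.zeros(2, 8, 4)              # (ubatches, max_tokens, hidden)
    tokens = torch.arange(12.0).reshape(3, 4)   # 3 tokens on this rank


    def stage(buffer: torch.Tensor, ubatch_id: int, chunk: torch.Tensor) -> torch.Tensor:
        # Pick the half of the buffer belonging to the current micro-batch,
        # then narrow to the chunk size; the underlying storage stays fixed.
        view = buffer[ubatch_id, : chunk.size(0)]
        view.copy_(chunk)
        return view


    out = stage(staging, ubatch_id=1, chunk=tokens)
    assert out.shape == (3, 4)
    assert torch.equal(staging[1, :3], tokens)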
    def _forward_impl(

Collaborator: Is it really needed to chunk the shared expert? I actually don't quite remember why this was required in the first place. I think that we could simplify things a lot if we only chunked the grouped expert. Shouldn't be done in this PR, but something to consider for the follow-up.

Collaborator (author): It makes things easier if it is chunked, since the MK naturally chunks it if it is the one executing it.

        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        final_shared_hidden_states, final_fused_hidden_states = (
            self._allocate_dp_chunking_outputs(
                hidden_states, router_logits, shared_experts_input
            )
        )

        ctx = get_forward_context()
        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

        # If the input to the MoE is sequence parallel then divide by sp_size
        # to find the maximum number of tokens for any individual dispatcher.
        if self.moe_config.is_sequence_parallel:
            max_tokens_across_dispatchers = cdiv(
                max_tokens_across_dispatchers, self.moe_config.sp_size
            )

        num_tokens = hidden_states.size(0)
        for chunk_idx, chunk_start_ in enumerate(
            range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
        ):
            chunk_start = chunk_start_
            chunk_end = min(
                chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
            )
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)

Contributor (on lines +185 to +186): The clamping logic for … The clamping should be against …
Suggested change: …

            chunk_sizes = ctx.dp_metadata.chunked_sizes(
                self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
            )
            with chunk_sizes:
                hidden_states_chunk = self._slice_and_copy_input(
                    self.batched_hidden_states,
                    hidden_states,
                    chunk_start,
                    chunk_end,
                )

                router_logits_chunk = self._slice_and_copy_input(
                    self.batched_router_logits,
                    router_logits,
                    chunk_start,
                    chunk_end,
                )

                shared_experts_input_chunk = (
                    shared_experts_input[chunk_start:chunk_end, :]
                    if shared_experts_input is not None
                    else None
                )

                # Delegate per-chunk computation to the inner runner.
                chunk_result = self._inner._forward_impl(
                    layer=layer,
                    hidden_states=hidden_states_chunk,
                    router_logits=router_logits_chunk,
                    shared_experts_input=shared_experts_input_chunk,
                )

                # Store outputs
                # TODO(bnell): document when chunk_start >= num_tokens
                if chunk_start < num_tokens:
                    if self.shared_experts is not None:
                        assert isinstance(chunk_result, tuple)
                        shared_output_chunk, hidden_states_chunk = chunk_result
                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                            hidden_states_chunk, non_blocking=True
                        )
                        assert shared_output_chunk is not None
                        assert final_shared_hidden_states is not None
                        final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
                            shared_output_chunk, non_blocking=True
                        )
                    else:
                        assert isinstance(chunk_result, torch.Tensor)
                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                            chunk_result, non_blocking=True
                        )

        if self.shared_experts is None:
            return final_fused_hidden_states
        else:
            assert final_shared_hidden_states is not None
            return (final_shared_hidden_states, final_fused_hidden_states)
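
To make the clamping comment on lines +185 to +186 easier to reason about, here is a standalone trace of the chunk-boundary arithmetic exactly as written in the diff, driven with made-up sizes (num_tokens, the chunk size, and the cross-dispatcher maximum are illustrative values, not taken from a real run):

    # Chunk-boundary arithmetic copied from the diff above, with toy values.
    num_tokens = 5                       # tokens actually present on this rank
    max_tokens_across_dispatchers = 16   # max across all DP ranks
    moe_dp_chunk_size_per_rank = 4

    for chunk_start_ in range(
        0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank
    ):
        chunk_start = chunk_start_
        chunk_end = min(
            chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
        )
        # clamp start and end, as in the diff:
        # start against num_tokens - 1, end against num_tokens
        chunk_start = min(chunk_start, num_tokens - 1)
        chunk_end = min(chunk_end, num_tokens)
        print((chunk_start, chunk_end))
    # Prints (0, 4), (4, 5), (4, 5), (4, 5): once this rank runs out of tokens
    # the window collapses to the last token, while the loop keeps iterating.

With these numbers every iteration after the first degenerates to the (4, 5) window; the loop presumably still runs out to max_tokens_across_dispatchers so that all DP ranks execute the same number of chunks.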

Collaborator: Just noting that it's a bit odd to have both the inner and outer being MoERunnerBase.

Collaborator: Should this actually be a mixin? Just an idea.

Collaborator (author): I agree this is a bit awkward. I could probably tease apart MoERunnerBase so that this could inherit from something a little more abstract. I'm not sure I follow how a mixin would work here.
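
For readers following the design discussion above, a generic sketch of the delegation pattern being weighed against a mixin, in plain Python with no vLLM types; Runner and ChunkingWrapper are made-up names. The wrapper owns only chunking state and forwards everything else to the inner object via __getattr__:

    # Toy illustration of wrapper-plus-__getattr__ delegation; not vLLM code.
    class Runner:
        def __init__(self, name: str):
            self.name = name

        def forward(self, xs: list[int]) -> list[int]:
            return [x * 2 for x in xs]


    class ChunkingWrapper:
        """Owns only chunking state; everything else resolves on the inner runner."""

        def __init__(self, inner: Runner, chunk: int):
            self._inner = inner
            self._chunk = chunk

        def __getattr__(self, attr):
            # Called only when normal lookup fails, so wrapper attributes win.
            return getattr(self._inner, attr)

        def forward(self, xs: list[int]) -> list[int]:
            out: list[int] = []
            for i in range(0, len(xs), self._chunk):
                out.extend(self._inner.forward(xs[i : i + self._chunk]))
            return out


    w = ChunkingWrapper(Runner("toy"), chunk=2)
    assert w.forward([1, 2, 3]) == [2, 4, 6]
    assert w.name == "toy"  # delegated through __getattr__

The trade-off under discussion is that the wrapper and the inner object share the same base type, so the delegation is invisible at the type level; a mixin or a narrower abstract base would make the split explicit.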