CuteDSL MoE fix redundant output buffer zeroing #2811
```diff
@@ -58,7 +58,6 @@
 from .moe_utils import (
     allocate_moe_sort_buffers,
     get_max_num_permuted_tokens,
-    moe_output_memset,
     moe_sort,
 )
 from .blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion import (
```
```diff
@@ -143,7 +142,7 @@ def _moe_core_impl(
     This function handles:
     1. moe_sort: Token routing computation
     2. GEMM1 + SwiGLU: First projection with activation
-    3. Async moe_output_memset: Zero output buffer (overlapped with GEMM1)
+    3. Async output zero: Zero output buffer (overlapped with GEMM1)
     4. GEMM2 + Finalize: Second projection with atomic scatter

     Args:
```
```diff
@@ -183,13 +182,20 @@ def _moe_core_impl(
     num_tokens = token_selected_experts.size(0)
     hidden_size = w2_weight.size(1)

-    # Allocate output if not provided
+    # Allocate output if not provided. The caller (wrapper or functional
+    # API) should pass a [:num_tokens] slice of the pre-allocated buffer
+    # when using CUDA graphs. The buffer is zeroed in Step 3 below.
     if moe_output is None:
         moe_output = torch.empty(
             (num_tokens, hidden_size),
             dtype=output_dtype,
             device=x.device,
         )
+    else:
+        assert moe_output.size(0) == num_tokens, (
+            f"moe_output must be sliced to num_tokens rows before calling "
+            f"_moe_core_impl (got {moe_output.size(0)}, expected {num_tokens})"
+        )
```
Comment on lines +195 to +198 (Contributor):

This assertion is a valuable safeguard. To improve debuggability when it fails, I recommend enhancing the error message to include the actual and expected tensor sizes. This provides immediate, actionable context to the developer, reducing debugging time.
Comment on lines +185 to +198 (Contributor):

Don't require an exact-row match. The new assert breaks callers that reuse a larger pre-allocated output buffer, even though the optimization only needs a `[:num_tokens]` slice. Proposed fix:

```diff
 if moe_output is None:
     moe_output = torch.empty(
         (num_tokens, hidden_size),
         dtype=output_dtype,
         device=x.device,
     )
 else:
-    assert moe_output.size(0) == num_tokens, (
-        "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-    )
+    if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
+        raise ValueError(
+            "moe_output must have shape [>= num_tokens, hidden_size]"
+        )
+    moe_output = moe_output[:num_tokens]
+    if not moe_output.is_contiguous():
+        raise ValueError("moe_output[:num_tokens] must be contiguous")
```
```diff
     # Get stream resources if using async memset
     if use_async_memset:
```
```diff
@@ -246,28 +252,23 @@ def _moe_core_impl(
         )
     )

-    # Step 3: Async moe_output_memset on auxiliary stream
+    # Step 3: Zero the active output slice before GEMM2 finalize.
+    # Finalize uses atomic scatter-add into `moe_output`, so it must start
+    # from zero each call. We zero only the active slice, not the full
+    # preallocated buffer. We do not use `moe_output_memset` here because
+    # FlashInfer's port always invokes the sparse kernel, missing the
+    # TRT-LLM dispatch that falls back to cudaMemsetAsync (dense zero)
+    # when !enable_alltoall || ep_size <= top_k. A dense zero of the
+    # active slice is correct for all configurations.
+    # TODO: add the TRTLLM all-to-all and `moe_output_memset` behavior
     if use_async_memset:
-        max_num_permuted_tokens = get_max_num_permuted_tokens(
-            num_tokens, top_k, num_local_experts, tile_size
-        )
         with torch.cuda.stream(aux_stream):
             main_event.wait()
-            moe_output_memset(
-                output=moe_output,
-                tile_idx_to_mn_limit=tile_idx_to_mn_limit,
-                expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
-                permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
-                num_non_exiting_tiles=num_non_exiting_tiles,
-                max_num_permuted_tokens=max_num_permuted_tokens,
-                top_k=top_k,
-                tile_size=tile_size,
-            )
+            moe_output.zero_()
             memset_event.record()
         memset_event.wait()
     else:
-        # Simple zero without async
-        moe_output[:num_tokens].zero_()
+        moe_output.zero_()

     # Step 4: GEMM2 + Finalize
     blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
```
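The event choreography in Step 3 (wait on `main_event`, zero on the aux stream, record `memset_event`, then have the main stream wait before finalize) can be modeled off-GPU. The sketch below is an illustrative assumption, not FlashInfer code: `threading.Event` stands in for `torch.cuda.Event`, a worker thread stands in for the aux stream, and list writes stand in for the dense zero.

```python
import threading

main_event = threading.Event()    # models main_event.record() after routing/GEMM1 launch
memset_event = threading.Event()  # models memset_event.record() after the zero

moe_output = [7.0] * 8            # stale values from a previous invocation
log = []

def aux_stream_memset():
    main_event.wait()             # aux stream waits for the main stream's event
    for i in range(len(moe_output)):
        moe_output[i] = 0.0       # dense zero of the active slice
    log.append("memset")
    memset_event.set()            # models memset_event.record()

worker = threading.Thread(target=aux_stream_memset)
worker.start()

log.append("gemm1")               # GEMM1 + SwiGLU proceeds on the main stream
main_event.set()                  # models main_event.record()
memset_event.wait()               # GEMM2 finalize must see a zeroed output
log.append("gemm2_finalize")
worker.join()

assert log == ["gemm1", "memset", "gemm2_finalize"]
assert moe_output == [0.0] * 8
```

The key invariant is the same as on the GPU: finalize is ordered after the zero, while the zero itself overlaps with GEMM1 work.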
```diff
@@ -431,7 +432,8 @@ def _allocate_buffers(self) -> None:
             (scale_size,), dtype=torch.uint8, device=self.device
         )

-        # Final output
+        # Final output: sliced to [:num_tokens] before each forward pass,
+        # then zeroed before GEMM2 finalize, typically on aux_stream.
         self._moe_output = torch.empty(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.output_dtype,
```
```diff
@@ -497,7 +499,8 @@ def _forward_with_tactic(
             gemm1_out_scale=self._gemm1_output_scale if self.use_cuda_graph else None,
             moe_output=moe_output
             if moe_output is not None
-            else (self._moe_output if self.use_cuda_graph else None),
+            # Slice the CUDA-graph buffer to the active batch.
+            else (self._moe_output[: x.shape[0]] if self.use_cuda_graph else None),
             aux_stream=self._aux_stream,
             main_event=self._main_event,
             memset_event=self._memset_event,
```
```diff
@@ -550,9 +553,10 @@ def run(
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )

-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
         if self.use_cuda_graph:
-            moe_output = self._moe_output
+            moe_output = self._moe_output[:num_tokens]
         else:
             moe_output = torch.empty(
                 (num_tokens, self.hidden_size),
```
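The point of slicing before zeroing is that a tensor slice is a view, so zeroing `[:num_tokens]` rows touches only the active memory while the rest of the preallocated buffer is left alone. The torch-free sketch below models this with nested lists: slicing the outer list copies it, but the row lists are shared, mimicking view semantics.

```python
# Model of the pre-allocated CUDA-graph buffer: max_num_tokens rows,
# holding stale values from an earlier step.
max_num_tokens, hidden_size = 8, 4
moe_output_full = [[7.0] * hidden_size for _ in range(max_num_tokens)]

num_tokens = 3
moe_output = moe_output_full[:num_tokens]  # rows shared, like a tensor view

for row in moe_output:                     # zero only num_tokens rows,
    for j in range(hidden_size):           # not max_num_tokens
        row[j] = 0.0

assert moe_output_full[0] == [0.0] * hidden_size           # active rows zeroed
assert moe_output_full[num_tokens] == [7.0] * hidden_size  # rest untouched
```

With small batches and a large `max_num_tokens`, zeroing only the active slice is the entire saving this PR targets.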
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This turns the public
out=path into a silent accumulation trap.blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4is still aflashinfer_api, so removing the internal zero on caller-supplied buffers breaks existingout=call sites with wrong answers rather than a loud failure. Please keep the zero-free fast path internal, or gate it behind an explicit opt-in, and preserve overwrite semantics on the public entry point.Also applies to: 321-325, 410-420
π€ Prompt for AI Agents