diff --git a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index 436e8da3dc..d496e85e69 100644 --- a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -298,7 +298,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1. token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16 out: Optional output tensor, shape (seq_len, n). Created if None. - This tensor is used for atomic accumulation, so it should be zero-initialized. + This tensor is used for atomic accumulation. If `out` is + provided, it must already be zero-initialized by the caller. + If `out` is None, this function allocates a zero-initialized + output tensor. Passing a non-zeroed `out` buffer will silently + produce incorrect results. ab_dtype: Data type for A and B matrices. Default: "float4_e2m1fn" sf_dtype: Data type for scale factors. Default: "float8_e4m3fn" out_dtype: Data type for output matrix. Default: "bfloat16" @@ -314,6 +318,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( Notes: - The output tensor is modified in-place using atomic adds for scatter-reduction. + - When out is provided it is NOT zeroed internally; the caller + must ensure the buffer is zeroed before each invocation. + In the main CuteDSL MoE path, _moe_core_impl handles this by + zeroing the active output slice before GEMM2, typically on an + auxiliary stream overlapped with GEMM1. - Call create_finalize_fusion_tensors() to create permuted_idx_to_expanded_idx and token_final_scales. 
- Requires SM100 (Blackwell) GPU architecture - The finalize fusion eliminates the need for a separate moe_unpermute kernel @@ -398,16 +407,17 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( f"cluster_shape_mn={cluster_shape_mn}, shape=({permuted_m}, {n}, {k}, {num_experts})" ) - # Create output tensor if not provided (zero-initialized for atomic adds) + # Create output tensor if not provided (zero-initialized for atomic adds). + # If out is provided, the caller is responsible for zeroing it before + # this call. The GEMM2 epilogue uses atomic scatter-add + # (out[token_idx] += ...), so any non-zero residual would corrupt + # results. if out is None: out = torch.zeros( (seq_len, n), dtype=cutlass_to_torch_dtype(out_dtype_cutlass), device=a.device, ) - else: - # Ensure output is zero for proper accumulation - out.zero_() # Get SM count if sm_count is None: diff --git a/flashinfer/fused_moe/cute_dsl/fused_moe.py b/flashinfer/fused_moe/cute_dsl/fused_moe.py index 4580e0704a..ec74bdf43e 100644 --- a/flashinfer/fused_moe/cute_dsl/fused_moe.py +++ b/flashinfer/fused_moe/cute_dsl/fused_moe.py @@ -58,7 +58,6 @@ from .moe_utils import ( allocate_moe_sort_buffers, get_max_num_permuted_tokens, - moe_output_memset, moe_sort, ) from .blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion import ( @@ -143,7 +142,7 @@ def _moe_core_impl( This function handles: 1. moe_sort: Token routing computation 2. GEMM1 + SwiGLU: First projection with activation - 3. Async moe_output_memset: Zero output buffer (overlapped with GEMM1) + 3. Async output zero: Zero output buffer (overlapped with GEMM1) 4. GEMM2 + Finalize: Second projection with atomic scatter Args: @@ -183,13 +182,20 @@ def _moe_core_impl( num_tokens = token_selected_experts.size(0) hidden_size = w2_weight.size(1) - # Allocate output if not provided + # Allocate output if not provided. 
The caller (wrapper or functional + # API) should pass a [:num_tokens] slice of the pre-allocated buffer + # when using CUDA graphs. The buffer is zeroed in Step 3 below. if moe_output is None: moe_output = torch.empty( (num_tokens, hidden_size), dtype=output_dtype, device=x.device, ) + else: + assert moe_output.size(0) == num_tokens, ( + f"moe_output must be sliced to num_tokens rows before calling " + f"_moe_core_impl (got {moe_output.size(0)}, expected {num_tokens})" + ) # Get stream resources if using async memset if use_async_memset: @@ -246,28 +252,23 @@ def _moe_core_impl( ) ) - # Step 3: Async moe_output_memset on auxiliary stream + # Step 3: Zero the active output slice before GEMM2 finalize. + # Finalize uses atomic scatter-add into `moe_output`, so it must start + # from zero each call. We zero only the active slice, not the full + # preallocated buffer. We do not use `moe_output_memset` here because + # FlashInfer's port always invokes the sparse kernel, missing the + # TRT-LLM dispatch that falls back to cudaMemsetAsync (dense zero) + # when !enable_alltoall || ep_size <= top_k. A dense zero of the + # active slice is correct for all configurations. 
+ # TODO: port TRT-LLM's all-to-all dispatch and re-enable `moe_output_memset` if use_async_memset: - max_num_permuted_tokens = get_max_num_permuted_tokens( - num_tokens, top_k, num_local_experts, tile_size - ) with torch.cuda.stream(aux_stream): main_event.wait() - moe_output_memset( - output=moe_output, - tile_idx_to_mn_limit=tile_idx_to_mn_limit, - expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, - permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, - num_non_exiting_tiles=num_non_exiting_tiles, - max_num_permuted_tokens=max_num_permuted_tokens, - top_k=top_k, - tile_size=tile_size, - ) + moe_output.zero_() memset_event.record() memset_event.wait() else: - # Simple zero without async - moe_output[:num_tokens].zero_() + moe_output.zero_() # Step 4: GEMM2 + Finalize blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( @@ -431,7 +432,8 @@ def _allocate_buffers(self) -> None: (scale_size,), dtype=torch.uint8, device=self.device ) - # Final output + # Final output — sliced to [:num_tokens] before each forward pass, + # then zeroed before GEMM2 finalize, typically on aux_stream. self._moe_output = torch.empty( (self.max_num_tokens, self.hidden_size), dtype=self.output_dtype, @@ -497,7 +499,8 @@ def _forward_with_tactic( gemm1_out_scale=self._gemm1_output_scale if self.use_cuda_graph else None, moe_output=moe_output if moe_output is not None - else (self._moe_output if self.use_cuda_graph else None), + # Slice the CUDA-graph buffer to the active batch. + else (self._moe_output[: x.shape[0]] if self.use_cuda_graph else None), aux_stream=self._aux_stream, main_event=self._main_event, memset_event=self._memset_event, @@ -550,9 +553,10 @@ def run( f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})" ) - # Allocate output buffer if not using pre-allocated one + # Slice the pre-allocated buffer to the active batch so that + # _moe_core_impl only zeros num_tokens rows, not max_num_tokens. 
if self.use_cuda_graph: - moe_output = self._moe_output + moe_output = self._moe_output[:num_tokens] else: moe_output = torch.empty( (num_tokens, self.hidden_size),