Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
out: Optional output tensor, shape (seq_len, n). Created if None.
This tensor is used for atomic accumulation, so it should be zero-initialized.
This tensor is used for atomic accumulation. If `out` is
provided, it must already be zero-initialized by the caller.
If `out` is None, this function allocates a zero-initialized
output tensor. Passing a non-zeroed `out` buffer will silently
produce incorrect results.
Comment on lines +301 to +305
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

This turns the public out= path into a silent accumulation trap.

blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 is still a flashinfer_api, so removing the internal zero on caller-supplied buffers breaks existing out= call sites with wrong answers rather than a loud failure. Please keep the zero-free fast path internal, or gate it behind an explicit opt-in, and preserve overwrite semantics on the public entry point.

Also applies to: 321-325, 410-420

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 301 - 305, The public API function
blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 currently removes
zero-initialization on caller-supplied out buffers, turning legitimate out=
usage into a silent accumulation bug; restore the original overwrite semantics
so when callers pass a non-None out buffer it is always zeroed (or explicitly
documented to be required zeroed), and move the zero-free fast path into an
internal helper (or gate it behind an explicit opt-in flag) used only by
internal callers; update the implementations referenced around the other similar
sites (the same pattern at the blocks around lines 321-325 and 410-420) to
follow the same approach so public entry points preserve overwrite semantics
while internal optimized paths can bypass zeroing when explicitly opted-in.

ab_dtype: Data type for A and B matrices. Default: "float4_e2m1fn"
sf_dtype: Data type for scale factors. Default: "float8_e4m3fn"
out_dtype: Data type for output matrix. Default: "bfloat16"
Expand All @@ -314,6 +318,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(

Notes:
- The output tensor is modified in-place using atomic adds for scatter-reduction.
- When out is provided it is NOT zeroed internally; the caller
must ensure the buffer is zeroed before each invocation.
In the main CuteDSL MoE path, _moe_core_impl handles this by
zeroing the active output slice before GEMM2, typically on an
auxiliary stream overlapped with GEMM1.
- Call create_finalize_fusion_tensors() to create permuted_idx_to_expanded_idx and token_final_scales.
- Requires SM100 (Blackwell) GPU architecture
- The finalize fusion eliminates the need for a separate moe_unpermute kernel
Expand Down Expand Up @@ -398,16 +407,17 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
f"cluster_shape_mn={cluster_shape_mn}, shape=({permuted_m}, {n}, {k}, {num_experts})"
)

# Create output tensor if not provided (zero-initialized for atomic adds)
# Create output tensor if not provided (zero-initialized for atomic adds).
# If out is provided, the caller is responsible for zeroing it before
# this call. The GEMM2 epilogue uses atomic scatter-add
# (out[token_idx] += ...), so any non-zero residual would corrupt
# results.
if out is None:
out = torch.zeros(
(seq_len, n),
dtype=cutlass_to_torch_dtype(out_dtype_cutlass),
device=a.device,
)
else:
# Ensure output is zero for proper accumulation
out.zero_()

# Get SM count
if sm_count is None:
Expand Down
50 changes: 27 additions & 23 deletions flashinfer/fused_moe/cute_dsl/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
from .moe_utils import (
allocate_moe_sort_buffers,
get_max_num_permuted_tokens,
moe_output_memset,
moe_sort,
)
from .blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion import (
Expand Down Expand Up @@ -143,7 +142,7 @@ def _moe_core_impl(
This function handles:
1. moe_sort: Token routing computation
2. GEMM1 + SwiGLU: First projection with activation
3. Async moe_output_memset: Zero output buffer (overlapped with GEMM1)
3. Async output zero: Zero output buffer (overlapped with GEMM1)
4. GEMM2 + Finalize: Second projection with atomic scatter

Args:
Expand Down Expand Up @@ -183,13 +182,20 @@ def _moe_core_impl(
num_tokens = token_selected_experts.size(0)
hidden_size = w2_weight.size(1)

# Allocate output if not provided
# Allocate output if not provided. The caller (wrapper or functional
# API) should pass a [:num_tokens] slice of the pre-allocated buffer
# when using CUDA graphs. The buffer is zeroed in Step 3 below.
if moe_output is None:
moe_output = torch.empty(
(num_tokens, hidden_size),
dtype=output_dtype,
device=x.device,
)
else:
assert moe_output.size(0) == num_tokens, (
f"moe_output must be sliced to num_tokens rows before calling "
f"_moe_core_impl (got {moe_output.size(0)}, expected {num_tokens})"
)
Comment on lines +195 to +198
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This assertion is a valuable safeguard. To improve debuggability when this assertion fails, I recommend enhancing the error message to include the actual and expected tensor sizes. This provides immediate, actionable context to the developer, reducing debugging time.

Suggested change
assert moe_output.size(0) == num_tokens, (
"moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
)
assert moe_output.size(0) == num_tokens, (
f"moe_output has {moe_output.size(0)} rows, but expected {num_tokens}. "
"It must be sliced to num_tokens rows before calling _moe_core_impl."
)

Comment on lines +185 to +198
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't require an exact-row moe_output here.

The new assert breaks callers that reuse a larger pre-allocated output buffer, even though the optimization only needs a [:num_tokens] view. Accept size(0) >= num_tokens, validate the hidden dimension, and slice locally so the fast path stays intact without changing the public contract.

Proposed fix
     if moe_output is None:
         moe_output = torch.empty(
             (num_tokens, hidden_size),
             dtype=output_dtype,
             device=x.device,
         )
     else:
-        assert moe_output.size(0) == num_tokens, (
-            "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-        )
+        if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
+            raise ValueError(
+                "moe_output must have shape [>= num_tokens, hidden_size]"
+            )
+        moe_output = moe_output[:num_tokens]
+        if not moe_output.is_contiguous():
+            raise ValueError("moe_output[:num_tokens] must be contiguous")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 185 - 197, The
assert in the _moe_core_impl input handling is too strict: allow callers to pass
a larger preallocated moe_output buffer by checking moe_output.size(0) >=
num_tokens and moe_output.size(1) == hidden_size (validate hidden dimension and
dtype/device if desired), then locally slice moe_output =
moe_output[:num_tokens] before using it; keep the existing allocation path when
moe_output is None (using torch.empty((num_tokens, hidden_size),
dtype=output_dtype, device=x.device)) so the fast path and CUDA-graph slice
semantics remain intact.


# Get stream resources if using async memset
if use_async_memset:
Expand Down Expand Up @@ -246,28 +252,23 @@ def _moe_core_impl(
)
)

# Step 3: Async moe_output_memset on auxiliary stream
# Step 3: Zero the active output slice before GEMM2 finalize.
# Finalize uses atomic scatter-add into `moe_output`, so it must start
# from zero each call. We zero only the active slice, not the full
# preallocated buffer. We do not use `moe_output_memset` here because
# FlashInfer's port always invokes the sparse kernel, missing the
# TRT-LLM dispatch that falls back to cudaMemsetAsync (dense zero)
# when !enable_alltoall || ep_size <= top_k. A dense zero of the
# active slice is correct for all configurations.
# TODO: add the TRTLLM all-to-all and `moe_output_memset` behavior
if use_async_memset:
max_num_permuted_tokens = get_max_num_permuted_tokens(
num_tokens, top_k, num_local_experts, tile_size
)
with torch.cuda.stream(aux_stream):
main_event.wait()
moe_output_memset(
output=moe_output,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
max_num_permuted_tokens=max_num_permuted_tokens,
top_k=top_k,
tile_size=tile_size,
)
moe_output.zero_()
memset_event.record()
memset_event.wait()
else:
# Simple zero without async
moe_output[:num_tokens].zero_()
moe_output.zero_()

# Step 4: GEMM2 + Finalize
blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
Expand Down Expand Up @@ -431,7 +432,8 @@ def _allocate_buffers(self) -> None:
(scale_size,), dtype=torch.uint8, device=self.device
)

# Final output
# Final output — sliced to [:num_tokens] before each forward pass,
# then zeroed before GEMM2 finalize, typically on aux_stream.
self._moe_output = torch.empty(
(self.max_num_tokens, self.hidden_size),
dtype=self.output_dtype,
Expand Down Expand Up @@ -497,7 +499,8 @@ def _forward_with_tactic(
gemm1_out_scale=self._gemm1_output_scale if self.use_cuda_graph else None,
moe_output=moe_output
if moe_output is not None
else (self._moe_output if self.use_cuda_graph else None),
# Slice the CUDA-graph buffer to the active batch.
else (self._moe_output[: x.shape[0]] if self.use_cuda_graph else None),
aux_stream=self._aux_stream,
main_event=self._main_event,
memset_event=self._memset_event,
Expand Down Expand Up @@ -550,9 +553,10 @@ def run(
f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
)

# Allocate output buffer if not using pre-allocated one
# Slice the pre-allocated buffer to the active batch so that
# _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
if self.use_cuda_graph:
moe_output = self._moe_output
moe_output = self._moe_output[:num_tokens]
else:
moe_output = torch.empty(
(num_tokens, self.hidden_size),
Expand Down
Loading