Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
out: Optional output tensor, shape (seq_len, n). Created if None.
This tensor is used for atomic accumulation, so it should be zero-initialized.
Must be zero-initialized by the caller when provided, as this kernel
uses atomic adds for scatter-reduction.
ab_dtype: Data type for A and B matrices. Default: "float4_e2m1fn"
sf_dtype: Data type for scale factors. Default: "float8_e4m3fn"
out_dtype: Data type for output matrix. Default: "bfloat16"
Expand Down Expand Up @@ -404,9 +405,6 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
dtype=cutlass_to_torch_dtype(out_dtype_cutlass),
device=a.device,
)
else:
# Ensure output is zero for proper accumulation
out.zero_()

# Get SM count
if sm_count is None:
Expand Down
23 changes: 6 additions & 17 deletions flashinfer/fused_moe/cute_dsl/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
from .moe_utils import (
allocate_moe_sort_buffers,
get_max_num_permuted_tokens,
moe_output_memset,
moe_sort,
)
from .blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion import (
Expand Down Expand Up @@ -246,30 +245,20 @@ def _moe_core_impl(
)
)

# Step 3: Async moe_output_memset on auxiliary stream
# Step 3: Zero output buffer (overlapped with GEMM1 on auxiliary stream)
# The finalize kernel uses atomic adds, so the output must be zeroed first.
# We zero the first num_tokens rows of the output on the aux stream
# (overlapping with GEMM1); the finalize kernel wrapper does not zero a
# caller-provided output itself, so this is the only zeroing that happens.
if use_async_memset:
max_num_permuted_tokens = get_max_num_permuted_tokens(
num_tokens, top_k, num_local_experts, tile_size
)
with torch.cuda.stream(aux_stream):
main_event.wait()
moe_output_memset(
output=moe_output,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
max_num_permuted_tokens=max_num_permuted_tokens,
top_k=top_k,
tile_size=tile_size,
)
moe_output[:num_tokens].zero_()
memset_event.record()
memset_event.wait()
else:
# Simple zero without async
moe_output[:num_tokens].zero_()

# Step 4: GEMM2 + Finalize
# Step 4: GEMM2 + Finalize (output was zeroed in Step 3; the wrapper does not zero it)
blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
a=intermediate,
b=w2_weight,
Expand Down
Loading