Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions flashinfer/fused_moe/cute_dsl/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,18 @@ def __init__(
self.device = device
self.enable_pdl = enable_pdl

# Pre-allocated buffers
# Pre-allocated tensor buffers (only when CUDA graphs are enabled)
self._moe_sort_buffers: Optional[Dict[str, torch.Tensor]] = None
self._gemm1_output: Optional[torch.Tensor] = None
self._gemm1_output_scale: Optional[torch.Tensor] = None
self._moe_output: Optional[torch.Tensor] = None

# Async memset resources: only created when CUDA graphs are enabled.
# With graphs, the same moe_output buffer is reused on every replay,
# so zeroing it on a separate stream overlaps with GEMM1. Without
# graphs, buffers are fresh each call and a synchronous .zero_()
on the main stream suffices — using the module-level singleton
# stream would race when multiple MoE layers run concurrently.
self._aux_stream: Optional[torch.cuda.Stream] = None
self._main_event: Optional[torch.cuda.Event] = None
self._memset_event: Optional[torch.cuda.Event] = None
Expand All @@ -410,6 +417,9 @@ def __init__(
)

if use_cuda_graph:
self._aux_stream = torch.cuda.Stream(device=self.device)
self._main_event = torch.cuda.Event()
self._memset_event = torch.cuda.Event()
self._allocate_buffers()

def _allocate_buffers(self) -> None:
Expand Down Expand Up @@ -444,18 +454,14 @@ def _allocate_buffers(self) -> None:
)

# Final output — sliced to [:num_tokens] before each forward pass,
# then zeroed before GEMM2 finalize, typically on aux_stream.
# then zeroed before GEMM2 finalize. Graph-enabled wrappers overlap
# the zero on aux_stream; non-graph wrappers zero on the main stream.
self._moe_output = torch.empty(
(self.max_num_tokens, self.hidden_size),
dtype=self.output_dtype,
device=self.device,
)

# CUDA resources
self._aux_stream = torch.cuda.Stream(device=self.device)
self._main_event = torch.cuda.Event()
self._memset_event = torch.cuda.Event()

def _forward_with_tactic(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -517,11 +523,11 @@ def _forward_with_tactic(
if moe_output is not None
# Slice the CUDA-graph buffer to the active batch.
else (self._moe_output[: x.shape[0]] if use_prealloc else None),
aux_stream=self._aux_stream,
main_event=self._main_event,
memset_event=self._memset_event,
aux_stream=self._aux_stream if use_prealloc else None,
main_event=self._main_event if use_prealloc else None,
memset_event=self._memset_event if use_prealloc else None,
output_dtype=output_dtype,
use_async_memset=True,
use_async_memset=use_prealloc,
enable_pdl=enable_pdl,
)

Expand Down Expand Up @@ -675,7 +681,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
moe_output=moe_output,
aux_stream=aux_stream,
output_dtype=output_dtype,
use_async_memset=True,
use_async_memset=aux_stream is not None,
enable_pdl=enable_pdl,
)

Expand Down
Loading