diff --git a/flashinfer/fused_moe/cute_dsl/fused_moe.py b/flashinfer/fused_moe/cute_dsl/fused_moe.py
index 8558bf92c2..af15ac0e7e 100644
--- a/flashinfer/fused_moe/cute_dsl/fused_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/fused_moe.py
@@ -388,11 +388,18 @@ def __init__(
         self.device = device
         self.enable_pdl = enable_pdl
 
-        # Pre-allocated buffers
+        # Pre-allocated tensor buffers (only when CUDA graphs are enabled)
         self._moe_sort_buffers: Optional[Dict[str, torch.Tensor]] = None
         self._gemm1_output: Optional[torch.Tensor] = None
         self._gemm1_output_scale: Optional[torch.Tensor] = None
         self._moe_output: Optional[torch.Tensor] = None
+
+        # Async memset resources: only created when CUDA graphs are enabled.
+        # With graphs, the same moe_output buffer is reused on every replay,
+        # so zeroing it on a separate stream overlaps with GEMM1. Without
+        # graphs, buffers are fresh each call and a synchronous .zero_()
+        # on the main stream suffices — using the module-level singleton
+        # stream would race when multiple MoE layers run concurrently.
         self._aux_stream: Optional[torch.cuda.Stream] = None
         self._main_event: Optional[torch.cuda.Event] = None
         self._memset_event: Optional[torch.cuda.Event] = None
@@ -410,6 +417,9 @@ def __init__(
         )
 
         if use_cuda_graph:
+            self._aux_stream = torch.cuda.Stream(device=self.device)
+            self._main_event = torch.cuda.Event()
+            self._memset_event = torch.cuda.Event()
             self._allocate_buffers()
 
     def _allocate_buffers(self) -> None:
@@ -444,18 +454,14 @@ def _allocate_buffers(self) -> None:
         )
 
         # Final output — sliced to [:num_tokens] before each forward pass,
-        # then zeroed before GEMM2 finalize, typically on aux_stream.
+        # then zeroed before GEMM2 finalize. Graph-enabled wrappers overlap
+        # the zero on aux_stream; non-graph wrappers zero on the main stream.
         self._moe_output = torch.empty(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.output_dtype,
             device=self.device,
         )
 
-        # CUDA resources
-        self._aux_stream = torch.cuda.Stream(device=self.device)
-        self._main_event = torch.cuda.Event()
-        self._memset_event = torch.cuda.Event()
-
     def _forward_with_tactic(
         self,
         x: torch.Tensor,
@@ -517,11 +523,11 @@ def _forward_with_tactic(
             if moe_output is not None
             # Slice the CUDA-graph buffer to the active batch.
             else (self._moe_output[: x.shape[0]] if use_prealloc else None),
-            aux_stream=self._aux_stream,
-            main_event=self._main_event,
-            memset_event=self._memset_event,
+            aux_stream=self._aux_stream if use_prealloc else None,
+            main_event=self._main_event if use_prealloc else None,
+            memset_event=self._memset_event if use_prealloc else None,
             output_dtype=output_dtype,
-            use_async_memset=True,
+            use_async_memset=use_prealloc,
             enable_pdl=enable_pdl,
         )
 
@@ -675,7 +681,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
         moe_output=moe_output,
         aux_stream=aux_stream,
         output_dtype=output_dtype,
-        use_async_memset=True,
+        use_async_memset=aux_stream is not None,
         enable_pdl=enable_pdl,
     )