diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 232b0b0daff6..eb125d7b36b3 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -487,11 +487,13 @@ def schedule(
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             self.scheduler_metadata[:n] = scheduler_metadata
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
+            # NOTE(woosuk, lucas): Zero from n-1 onwards. Positions >= n must be
+            # zeroed to prevent invalid metadata from being used. The
+            # semaphore at position n-1 must also be zeroed before each
+            # forward pass because when num_splits == 1, FA3's internal
+            # semaphore reset uses PyTorch zero_() which isn't captured in
+            # CUDA graphs.
+            self.scheduler_metadata[n - 1 :] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
         attn_metadata = FlashAttentionMetadata(
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index f0ba259362ff..0f07af238e82 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -224,11 +224,13 @@ def _build_decode(
                 f"{self.scheduler_metadata.shape[0]}"
             )
             self.scheduler_metadata[:n] = scheduler_metadata
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
+            # NOTE(woosuk, lucas): Zero from n-1 onwards. Positions >= n must be
+            # zeroed to prevent invalid metadata from being used. The
+            # semaphore at position n-1 must also be zeroed before each
+            # forward pass because when num_splits == 1, FA3's internal
+            # semaphore reset uses PyTorch zero_() which isn't captured in
+            # CUDA graphs.
+            self.scheduler_metadata[n - 1 :] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
         metadata = FlashAttnMLADecodeMetadata(
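For context, a minimal standalone sketch of the persistent-buffer pattern both hunks patch. The buffer size, tensor names, and the free-standing function are illustrative assumptions, not vLLM's actual structure; only the indexing change mirrors the diff.

import torch

# Hypothetical persistent buffer captured by the CUDA graph; the size and
# name are assumptions for illustration, not vLLM's actual values.
MAX_METADATA_LEN = 256
scheduler_metadata_buf = torch.zeros(
    MAX_METADATA_LEN, dtype=torch.int32, device="cuda"
)

def prepare_scheduler_metadata(scheduler_metadata: torch.Tensor) -> torch.Tensor:
    """Stage fresh FA3 scheduler metadata into the graph-captured buffer."""
    n = scheduler_metadata.shape[0]
    assert n <= scheduler_metadata_buf.shape[0]
    scheduler_metadata_buf[:n] = scheduler_metadata
    # Zero from n-1, not n: positions >= n hold stale metadata from a
    # previous (larger) batch, and position n-1 is FA3's semaphore slot.
    # When num_splits == 1, FA3 resets that semaphore via a zero_() call
    # that is not captured in the CUDA graph, so graph replays would
    # otherwise see a stale nonzero value.
    scheduler_metadata_buf[n - 1 :] = 0  # previously buf[n:] = 0
    return scheduler_metadata_buf[:n]

Because the buffer (not a fresh tensor) is what the captured graph reads, every slot a replay can touch must be reset on the eager path before launch; the fix simply widens that reset by one element to cover the semaphore.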