diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index e8921f8a1c40..c6e3f92dc4a6 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -88,6 +88,13 @@ def __init__( # TODO: we can disambiguate between decode and mixed-prefill decode here # so we can only use the persistent buffer if a cudagraph is actually # being used. + + # paged_kv_last_page_len is always 1s (kernel block size is always 1), + # so we create it once and reuse slices in both eager and cudagraph modes. + self.paged_kv_last_page_len = torch.ones( + max_num_reqs, dtype=torch.int32, device=device + ) + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr = torch.zeros( max_num_reqs + 1, dtype=torch.int32, device=device @@ -95,9 +102,6 @@ def __init__( self.paged_kv_indices = torch.zeros( max_num_pages, dtype=torch.int32, device=device ) - self.paged_kv_last_page_len = torch.zeros( - max_num_reqs, dtype=torch.int32, device=device - ) self.qo_indptr = torch.zeros( max_num_reqs + 1, dtype=torch.int32, device=device @@ -122,7 +126,9 @@ def _build_decode( ).unsqueeze(0) < seq_lens_device.unsqueeze(1) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device) + # kernel block size is always 1, so each page has exactly 1 token. + # last_page_len is always 1 - just slice the pre-initialized buffer. + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] paged_kv_indptr = torch.cat( [ @@ -148,11 +154,8 @@ def _build_decode( self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] - self.paged_kv_last_page_len[:num_reqs].copy_( - paged_kv_last_page_len, non_blocking=True - ) - self.paged_kv_last_page_len[num_reqs:].fill_(1) - paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + # paged_kv_last_page_len already uses the pre-initialized buffer slice + # (set above), so no copy needed - buffer is always 1s. self.qo_indptr[: 1 + num_reqs].copy_( query_start_loc_device, non_blocking=True