diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index eda50155ddac..1b030eaf140a 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -855,8 +855,12 @@ def mamba_get_block_table_tensor( (seq_lens - 1) // kv_cache_spec.block_size, min=0, ) + # Use int32 for arithmetic to avoid dtype promotion overhead, + # then convert to int64 for gather (which requires Long indices) offsets = torch.arange( - 1 + kv_cache_spec.num_speculative_blocks, device=block_table.device + 1 + kv_cache_spec.num_speculative_blocks, + device=block_table.device, + dtype=torch.int32, ) - indices_to_gather = start_indices.unsqueeze(1) + offsets + indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64) return torch.gather(block_table, 1, indices_to_gather)