Commit 538c9a0

Update type hints in gdn_attn
Signed-off-by: simondanielsson <[email protected]>
1 parent: 7b41ac4


vllm/v1/attention/backends/gdn_attn.py

Lines changed: 35 additions & 35 deletions
```diff
@@ -36,10 +36,9 @@ class GDNAttentionMetadata:
     num_spec_decode_tokens: int
     num_actual_tokens: int
 
-    has_initial_state: Optional[torch.Tensor] = None
-    block_size: Optional[int] = None
-    chunk_size: Optional[int] = None
-
+    has_initial_state: torch.Tensor | None = None
+    block_size: int | None = None
+    chunk_size: int | None = None
 
     spec_query_start_loc: torch.Tensor | None = None  # shape: [num_spec_decodes + 1,]
     non_spec_query_start_loc: torch.Tensor | None = (
```
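
The pattern repeated throughout this commit is the PEP 604 spelling of optional types: `X | None` in place of `typing.Optional[X]`. The two forms are interchangeable; a minimal standalone sketch (not vLLM code, assuming Python 3.10+ where `|` between types works at runtime):

```python
# Minimal sketch (not vLLM code): PEP 604 unions vs. typing.Optional.
# Python 3.10+ evaluates `X | Y` on types to a types.UnionType.
from typing import Optional, get_args

new_style = int | None
old_style = Optional[int]

assert new_style == old_style                    # the two spellings compare equal
assert get_args(new_style) == (int, type(None))  # same union members
assert isinstance(None, int | None)              # usable directly in isinstance checks
```

On older interpreters the `X | None` annotation still parses under `from __future__ import annotations` (annotations stay unevaluated strings), but the runtime `|` operator on types needs 3.10+.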
```diff
@@ -57,19 +56,19 @@ class GDNAttentionMetadata:
     num_accepted_tokens: torch.Tensor | None = None  # shape: [batch,]
 
     # Decode-side APC metadata
-    state_indices_tensor_d: Optional[torch.Tensor] = None
-    state_indices_tensor_p: Optional[torch.Tensor] = None
-    block_idx_last_computed_token_d: Optional[torch.Tensor] = None
-    block_idx_last_scheduled_token_d: Optional[torch.Tensor] = None
+    state_indices_tensor_d: torch.Tensor | None = None
+    state_indices_tensor_p: torch.Tensor | None = None
+    block_idx_last_computed_token_d: torch.Tensor | None = None
+    block_idx_last_scheduled_token_d: torch.Tensor | None = None
 
     # Prefill-side APC metadata
-    block_idx_first_scheduled_token_p: Optional[torch.Tensor] = None
-    block_idx_last_computed_token_p: Optional[torch.Tensor] = None
-    block_idx_last_scheduled_token_p: Optional[torch.Tensor] = None
-    seq_idx_p: Optional[torch.Tensor] = None
-    cu_chunk_seqlen_p: Optional[torch.Tensor] = None
-    last_chunk_indices_p: Optional[torch.Tensor] = None
-    num_computed_tokens_p: Optional[torch.Tensor] = None
+    block_idx_first_scheduled_token_p: torch.Tensor | None = None
+    block_idx_last_computed_token_p: torch.Tensor | None = None
+    block_idx_last_scheduled_token_p: torch.Tensor | None = None
+    seq_idx_p: torch.Tensor | None = None
+    cu_chunk_seqlen_p: torch.Tensor | None = None
+    last_chunk_indices_p: torch.Tensor | None = None
+    num_computed_tokens_p: torch.Tensor | None = None
 
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: dict | None = None
```
```diff
@@ -103,12 +102,13 @@ def __init__(
         self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() or 64
-        if self.vllm_config.cache_config.enable_prefix_caching:
-            if kv_cache_spec.block_size % self.chunk_size != 0:
-                raise ValueError(
-                    "GDN prefix caching requires the mamba block size to be a "
-                    "multiple of the kernel chunk size."
-                )
+        if self.vllm_config.cache_config.enable_prefix_caching and (
+            kv_cache_spec.block_size % self.chunk_size != 0
+        ):
+            raise ValueError(
+                "GDN prefix caching requires the mamba block size to be a "
+                "multiple of the kernel chunk size."
+            )
 
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
```
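
Collapsing the nested `if` into a single condition does not change behavior: the divisibility check still only runs when prefix caching is enabled. A standalone sketch of the same guard, with illustrative sizes (not vLLM defaults):

```python
# Minimal sketch (not vLLM code): raise only when prefix caching is enabled
# AND the block size is not a whole multiple of the kernel chunk size.
def check_block_size(enable_prefix_caching: bool, block_size: int, chunk_size: int) -> None:
    if enable_prefix_caching and (block_size % chunk_size != 0):
        raise ValueError(
            "GDN prefix caching requires the mamba block size to be a "
            "multiple of the kernel chunk size."
        )

check_block_size(True, 128, 64)   # ok: 128 = 2 * 64
check_block_size(False, 80, 64)   # ok: check is skipped when APC is off
# check_block_size(True, 80, 64)  # would raise ValueError
```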
```diff
@@ -247,23 +247,23 @@ def build(  # type: ignore[override]
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
         enable_apc = self.vllm_config.cache_config.enable_prefix_caching
-        block_size_value: Optional[int] = None
-        chunk_size_value: Optional[int] = None
+        block_size_value: int | None = None
+        chunk_size_value: int | None = None
         if enable_apc:
             block_size_value = self.kv_cache_spec.block_size
             chunk_size_value = self.chunk_size
-        state_indices_tensor_d: Optional[torch.Tensor] = None
-        state_indices_tensor_p: Optional[torch.Tensor] = None
-        block_idx_last_computed_token_d: Optional[torch.Tensor] = None
-        block_idx_last_scheduled_token_d: Optional[torch.Tensor] = None
-        block_idx_first_scheduled_token_p: Optional[torch.Tensor] = None
-        block_idx_last_computed_token_p: Optional[torch.Tensor] = None
-        block_idx_last_scheduled_token_p: Optional[torch.Tensor] = None
-        num_computed_tokens_p: Optional[torch.Tensor] = None
-        seq_idx_p: Optional[torch.Tensor] = None
-        cu_chunk_seqlen_p: Optional[torch.Tensor] = None
-        last_chunk_indices_p: Optional[torch.Tensor] = None
-        non_spec_query_start_loc_cpu: Optional[torch.Tensor] = None
+        state_indices_tensor_d: torch.Tensor | None = None
+        state_indices_tensor_p: torch.Tensor | None = None
+        block_idx_last_computed_token_d: torch.Tensor | None = None
+        block_idx_last_scheduled_token_d: torch.Tensor | None = None
+        block_idx_first_scheduled_token_p: torch.Tensor | None = None
+        block_idx_last_computed_token_p: torch.Tensor | None = None
+        block_idx_last_scheduled_token_p: torch.Tensor | None = None
+        num_computed_tokens_p: torch.Tensor | None = None
+        seq_idx_p: torch.Tensor | None = None
+        cu_chunk_seqlen_p: torch.Tensor | None = None
+        last_chunk_indices_p: torch.Tensor | None = None
+        non_spec_query_start_loc_cpu: torch.Tensor | None = None
 
         if (
             not self.use_spec_decode
```
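
This hunk pre-declares a batch of locals as `torch.Tensor | None` before they are conditionally assigned. Annotating the declaration lets a type checker accept both the `None` default and a later tensor assignment on one path; a minimal sketch with a hypothetical name:

```python
# Minimal sketch (not vLLM code): annotate an optional local once so both
# the None default and the conditional assignment satisfy mypy/pyright.
def pick_block_size(enable_apc: bool) -> int | None:
    block_size_value: int | None = None  # declared optional up front
    if enable_apc:
        block_size_value = 64            # narrowed to int on this path
    return block_size_value
```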
```diff
@@ -277,7 +277,7 @@ def build(  # type: ignore[override]
             num_spec_decodes = 0
         else:
             spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
-            num_spec_decodes = spec_sequence_masks.sum().item()
+            num_spec_decodes = int(spec_sequence_masks.sum().item())
         if num_spec_decodes == 0:
             spec_sequence_masks = None
         else:
```
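
The last hunk is the only change that is not a pure respelling: `Tensor.item()` returns a Python scalar, and torch's stubs type it as a broad number rather than `int`, so wrapping the result in `int(...)` keeps `num_spec_decodes` an `int` for the type checker (a runtime no-op for a boolean-mask sum). A minimal sketch:

```python
# Minimal sketch (not vLLM code): summing a boolean mask and casting the
# Python scalar from .item() to int for the benefit of static type checkers.
import torch

mask = torch.tensor([True, False, True])  # stand-in for spec_sequence_masks
num_true = int(mask.sum().item())         # .item() is typed as a broad Number
assert num_true == 2
```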
