@@ -36,10 +36,9 @@ class GDNAttentionMetadata:
3636 num_spec_decode_tokens : int
3737 num_actual_tokens : int
3838
39- has_initial_state : Optional [torch .Tensor ] = None
40- block_size : Optional [int ] = None
41- chunk_size : Optional [int ] = None
4239 has_initial_state : torch .Tensor | None = None
40+ block_size : int | None = None
41+ chunk_size : int | None = None
4342
4443 spec_query_start_loc : torch .Tensor | None = None # shape: [num_spec_decodes + 1,]
4544 non_spec_query_start_loc : torch .Tensor | None = (
@@ -57,19 +56,19 @@ class GDNAttentionMetadata:
5756 num_accepted_tokens : torch .Tensor | None = None # shape: [batch,]
5857
5958 # Decode-side APC metadata
60- state_indices_tensor_d : Optional [ torch .Tensor ] = None
61- state_indices_tensor_p : Optional [ torch .Tensor ] = None
62- block_idx_last_computed_token_d : Optional [ torch .Tensor ] = None
63- block_idx_last_scheduled_token_d : Optional [ torch .Tensor ] = None
59+ state_indices_tensor_d : torch .Tensor | None = None
60+ state_indices_tensor_p : torch .Tensor | None = None
61+ block_idx_last_computed_token_d : torch .Tensor | None = None
62+ block_idx_last_scheduled_token_d : torch .Tensor | None = None
6463
6564 # Prefill-side APC metadata
66- block_idx_first_scheduled_token_p : Optional [ torch .Tensor ] = None
67- block_idx_last_computed_token_p : Optional [ torch .Tensor ] = None
68- block_idx_last_scheduled_token_p : Optional [ torch .Tensor ] = None
69- seq_idx_p : Optional [ torch .Tensor ] = None
70- cu_chunk_seqlen_p : Optional [ torch .Tensor ] = None
71- last_chunk_indices_p : Optional [ torch .Tensor ] = None
72- num_computed_tokens_p : Optional [ torch .Tensor ] = None
65+ block_idx_first_scheduled_token_p : torch .Tensor | None = None
66+ block_idx_last_computed_token_p : torch .Tensor | None = None
67+ block_idx_last_scheduled_token_p : torch .Tensor | None = None
68+ seq_idx_p : torch .Tensor | None = None
69+ cu_chunk_seqlen_p : torch .Tensor | None = None
70+ last_chunk_indices_p : torch .Tensor | None = None
71+ num_computed_tokens_p : torch .Tensor | None = None
7372
7473 # The following attributes are for triton implementation of causal_conv1d
7574 nums_dict : dict | None = None
@@ -103,12 +102,13 @@ def __init__(
103102 self ._init_reorder_batch_threshold (1 , self .use_spec_decode )
104103
105104 self .chunk_size = vllm_config .model_config .get_mamba_chunk_size () or 64
106- if self .vllm_config .cache_config .enable_prefix_caching :
107- if kv_cache_spec .block_size % self .chunk_size != 0 :
108- raise ValueError (
109- "GDN prefix caching requires the mamba block size to be a "
110- "multiple of the kernel chunk size."
111- )
105+ if self .vllm_config .cache_config .enable_prefix_caching and (
106+ kv_cache_spec .block_size % self .chunk_size != 0
107+ ):
108+ raise ValueError (
109+ "GDN prefix caching requires the mamba block size to be a "
110+ "multiple of the kernel chunk size."
111+ )
112112
113113 self .use_full_cuda_graph = (
114114 self .compilation_config .cudagraph_mode .has_full_cudagraphs ()
@@ -247,23 +247,23 @@ def build( # type: ignore[override]
247247 nums_dict , batch_ptr , token_chunk_offset_ptr = None , None , None
248248
249249 enable_apc = self .vllm_config .cache_config .enable_prefix_caching
250- block_size_value : Optional [ int ] = None
251- chunk_size_value : Optional [ int ] = None
250+ block_size_value : int | None = None
251+ chunk_size_value : int | None = None
252252 if enable_apc :
253253 block_size_value = self .kv_cache_spec .block_size
254254 chunk_size_value = self .chunk_size
255- state_indices_tensor_d : Optional [ torch .Tensor ] = None
256- state_indices_tensor_p : Optional [ torch .Tensor ] = None
257- block_idx_last_computed_token_d : Optional [ torch .Tensor ] = None
258- block_idx_last_scheduled_token_d : Optional [ torch .Tensor ] = None
259- block_idx_first_scheduled_token_p : Optional [ torch .Tensor ] = None
260- block_idx_last_computed_token_p : Optional [ torch .Tensor ] = None
261- block_idx_last_scheduled_token_p : Optional [ torch .Tensor ] = None
262- num_computed_tokens_p : Optional [ torch .Tensor ] = None
263- seq_idx_p : Optional [ torch .Tensor ] = None
264- cu_chunk_seqlen_p : Optional [ torch .Tensor ] = None
265- last_chunk_indices_p : Optional [ torch .Tensor ] = None
266- non_spec_query_start_loc_cpu : Optional [ torch .Tensor ] = None
255+ state_indices_tensor_d : torch .Tensor | None = None
256+ state_indices_tensor_p : torch .Tensor | None = None
257+ block_idx_last_computed_token_d : torch .Tensor | None = None
258+ block_idx_last_scheduled_token_d : torch .Tensor | None = None
259+ block_idx_first_scheduled_token_p : torch .Tensor | None = None
260+ block_idx_last_computed_token_p : torch .Tensor | None = None
261+ block_idx_last_scheduled_token_p : torch .Tensor | None = None
262+ num_computed_tokens_p : torch .Tensor | None = None
263+ seq_idx_p : torch .Tensor | None = None
264+ cu_chunk_seqlen_p : torch .Tensor | None = None
265+ last_chunk_indices_p : torch .Tensor | None = None
266+ non_spec_query_start_loc_cpu : torch .Tensor | None = None
267267
268268 if (
269269 not self .use_spec_decode
@@ -277,7 +277,7 @@ def build( # type: ignore[override]
277277 num_spec_decodes = 0
278278 else :
279279 spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
280- num_spec_decodes = spec_sequence_masks .sum ().item ()
280+ num_spec_decodes = int ( spec_sequence_masks .sum ().item () )
281281 if num_spec_decodes == 0 :
282282 spec_sequence_masks = None
283283 else :
0 commit comments