diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 66697132b365..3793a83e74e1 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -71,7 +71,6 @@ def __init__(
         self.device = device
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
         self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
@@ -470,22 +469,23 @@ def propose(
             common_attn_metadata._num_computed_tokens_cpu += 1

             # Compute the slot mapping.
+            block_size = attn_metadata_builder.kv_cache_spec.block_size
             if self.uses_mrope:
                 # all dimensions of positions are the same
-                block_numbers = clamped_positions[0] // self.block_size
+                block_numbers = clamped_positions[0] // block_size
             else:
-                block_numbers = clamped_positions // self.block_size
+                block_numbers = clamped_positions // block_size
             block_ids = common_attn_metadata.block_table_tensor.gather(
                 dim=1, index=block_numbers.view(-1, 1)
             )
             block_ids = block_ids.view(-1)
             if self.uses_mrope:
                 common_attn_metadata.slot_mapping = (
-                    block_ids * self.block_size + clamped_positions[0] % self.block_size
+                    block_ids * block_size + clamped_positions[0] % block_size
                 )
             else:
                 common_attn_metadata.slot_mapping = (
-                    block_ids * self.block_size + clamped_positions % self.block_size
+                    block_ids * block_size + clamped_positions % block_size
                 )
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
@@ -800,12 +800,11 @@ def propose_tree(
             attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

             # Compute the slot mapping.
+            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
             query_positions = flattened_draft_positions[:, level : level + query_len]
-            block_numbers = query_positions // self.block_size
+            block_numbers = query_positions // block_size
             block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
-            slot_mapping = (
-                block_ids * self.block_size + query_positions % self.block_size
-            )
+            slot_mapping = block_ids * block_size + query_positions % block_size
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
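
The change drops the cached `self.block_size` (taken from the global `cache_config`) and instead reads the block size from the `kv_cache_spec` of the relevant attention metadata builder at the point where the slot mapping is computed. The slot-mapping arithmetic itself is unchanged: a token at logical position `p` lives at offset `p % block_size` inside the physical block `block_table[p // block_size]`. Below is a minimal, self-contained sketch of that arithmetic with made-up tensors; `block_table`, `positions`, and the concrete values are illustrative stand-ins, not vLLM's actual objects.

```python
import torch

# Hypothetical paged-KV setup: each sequence owns one row of the block table,
# which maps logical block index -> physical block id.
block_size = 4
block_table = torch.tensor([[7, 2, 5],   # sequence 0
                            [0, 9, 3]])  # sequence 1
# One draft-token position per sequence, as in the non-mrope propose() path.
positions = torch.tensor([5, 9])

# Logical block each position falls into, then its physical block id.
block_numbers = positions // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)

# Flat slot index into the paged KV cache.
slot_mapping = block_ids * block_size + positions % block_size
print(slot_mapping)  # tensor([ 9, 13]): block 2 offset 1, block 3 offset 1
```

Reading the block size from the per-group `kv_cache_spec` rather than a single cached config value presumably allows attention groups whose KV-cache block size differs from the global `cache_config` value to still compute correct slot indices.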