@@ -201,6 +201,7 @@
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -211,7 +212,6 @@
     AttentionMetadataBuilder, CommonAttentionMetadata,
     reoder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -225,7 +225,6 @@
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
@@ -346,22 +345,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
 
     def __init__(self,
-                 runner: "GPUModelRunner",
                  kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
-        self.runner = runner
-        scheduler_config = runner.scheduler_config
-        model_config = runner.model_config
-        cache_config = runner.cache_config
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        scheduler_config = vllm_config.scheduler_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        parallel_config = vllm_config.parallel_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        self.num_heads = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.mla_dims = get_mla_dims(model_config)
+        self.num_heads = self.model_config.get_num_attention_heads(
+            parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
-        self.kv_cache_spec = kv_cache_spec
 
         # Dont try to access the runner on AMD
         if self.aot_schedule:
@@ -372,7 +372,7 @@ def __init__(self,
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
-                    8 * model_config.max_model_len, 4 *
+                    8 * self.model_config.max_model_len, 4 *
                     scheduler_config.max_num_seqs * cache_config.block_size),
                 # For long-context models try not to over-allocate limiting
                 # kv-cache space, limiting it to 64k tokens,
@@ -387,11 +387,10 @@ def __init__(self,
                 scheduler_config.max_num_seqs * cache_config.block_size
             self.chunked_prefill_workspace = torch.empty(
                 (self.chunked_prefill_workspace_size,
-                 model_config.get_head_size()),
-                dtype=model_config.dtype,
-                device=runner.device,
+                 self.model_config.get_head_size()),
+                dtype=self.model_config.dtype,
+                device=device,
             )
-        self.block_table = block_table
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -432,7 +431,7 @@ def build(self,
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
-        device = self.runner.device
+        device = self.device
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
@@ -538,7 +537,7 @@ def build(self,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
-            head_dim=self.runner.model_config.get_head_size(),
+            head_dim=self.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
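
In summary, the builder now takes `(kv_cache_spec, vllm_config, device)` instead of `(runner, kv_cache_spec, block_table)`, and the per-request block table reaches `build()` through `common_attn_metadata.block_table_tensor`. The sketch below shows how a caller might instantiate the builder against the new signature; it assumes `MLACommonMetadataBuilder` remains importable from `vllm.v1.attention.backends.mla.common`, and the helper function and variable names are purely illustrative.

```python
# Minimal sketch of constructing the builder after this refactor.
# Assumption: MLACommonMetadataBuilder is still exported from
# vllm.v1.attention.backends.mla.common; the helper name is hypothetical.
import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.kv_cache_interface import AttentionSpec


def make_mla_metadata_builder(vllm_config: VllmConfig,
                              kv_cache_spec: AttentionSpec,
                              device: torch.device) -> MLACommonMetadataBuilder:
    # No GPUModelRunner or BlockTable is passed anymore: the builder reads
    # model/scheduler/cache/parallel config from VllmConfig and allocates its
    # chunked-prefill workspace directly on `device`.
    return MLACommonMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                    vllm_config=vllm_config,
                                    device=device)
```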