@@ -164,6 +164,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_attn_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 NPUModelRunner.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
 
@@ -196,6 +214,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
+
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -542,10 +571,7 @@ def _process_reqs(
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        else:
-            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -960,16 +986,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
-        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-                kv_cache_config=kv_cache_config,
-            )
 
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec