 
 import torch
 
-import vllm.envs as envs
-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
     SimpleBuffer)
 from vllm.logger import init_logger
@@ -37,9 +37,7 @@ def __init__(
     ):
 
         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-        self.is_deepseek_mla = config.model_config.is_deepseek_mla
-        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.kv_helper = kv_helper(config)
 
         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
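The three attributes removed here (`tp_size`, `is_deepseek_mla`, `use_mla_opt`) move into the new `model_aware_kv_ops_helper`. Its actual implementation lives in `vllm/distributed/kv_transfer/kv_connector/utils.py`; the snippet below is only a sketch of its constructor inferred from the removed lines, not the real class body.

```python
# Hypothetical sketch of the helper's state, inferred from the connector
# attributes removed in the hunk above; see
# vllm/distributed/kv_transfer/kv_connector/utils.py for the real class.
import vllm.envs as envs


class model_aware_kv_ops_helper:

    def __init__(self, config):
        # The same model/parallel facts the connector previously cached itself.
        self.tp_size = config.parallel_config.tensor_parallel_size
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
```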
@@ -165,31 +163,7 @@ def send_kv_caches_and_hidden_states(
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-
-        # Deepseek's MLA (Multi-head Latent Attention) uses two different
-        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
-        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
-        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
-        # kv_lora_rank + qk_rope_head_dim].
-        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
-        # to a kv_cache shape of [2, num_blks, blk_size,
-        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
-        # For more details, see vllm/attention/backends/mla/common.py.
-        if self.is_deepseek_mla and self.use_mla_opt:
-            head_size = model_config.kv_lora_rank + \
-                model_config.qk_rope_head_dim
-            num_heads = 1
-        elif self.is_deepseek_mla and not self.use_mla_opt:
-            head_size = model_config.qk_nope_head_dim + \
-                model_config.qk_rope_head_dim
-        else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
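`get_model_args` replaces the inlined head-count/head-size computation, including the Deepseek MLA special cases documented in the removed comment. A standalone sketch of what the helper presumably returns, reconstructed from the removed branch (the real method uses the state captured by the helper's constructor, not explicit parameters):

```python
# Sketch only: not the actual implementation of
# model_aware_kv_ops_helper.get_model_args.
def get_model_args(model_executable, tp_size, is_deepseek_mla, use_mla_opt):
    model_config = model_executable.model.config
    num_heads = int(model_config.num_key_value_heads / tp_size)
    hidden_size = model_config.hidden_size
    num_attention_heads = model_config.num_attention_heads

    if is_deepseek_mla and use_mla_opt:
        # Forward absorb: a single latent "head" of
        # kv_lora_rank + qk_rope_head_dim elements.
        head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim
        num_heads = 1
    elif is_deepseek_mla and not use_mla_opt:
        # Standard attention path for MLA models.
        head_size = (model_config.qk_nope_head_dim +
                     model_config.qk_rope_head_dim)
    else:
        head_size = getattr(model_config, "head_dim",
                            int(hidden_size // num_attention_heads))
    return num_heads, head_size
```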
@@ -212,13 +186,8 @@ def send_kv_caches_and_hidden_states(
 
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
-                else:
-                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)
 
                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
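`get_kv_from_cache` likewise folds the per-layer reshape logic into the helper. A minimal sketch, assuming it still distinguishes the single latent MLA cache from the usual split key/value cache:

```python
# Sketch only: standalone equivalent of kv_helper.get_kv_from_cache,
# based on the branch removed above. With MLA forward absorb there is a
# single latent cache, so the "key" and "value" views alias the same tensor.
def get_kv_from_cache(kv_cache, num_heads, head_size,
                      is_deepseek_mla, use_mla_opt):
    if is_deepseek_mla and use_mla_opt:
        key_cache = kv_cache.reshape(-1, num_heads, head_size)
        value_cache = kv_cache.reshape(-1, num_heads, head_size)
    else:
        key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
        value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
    return key_cache, value_cache
```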
@@ -248,12 +217,12 @@ def recv_kv_caches_and_hidden_states(
         # and hidden states.
         bypass_model_exec = True
 
-        model_config = model_executable.model.config
-
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer
 
         hidden_or_intermediate_states_for_one_req = []
 
@@ -312,41 +281,19 @@ def recv_kv_caches_and_hidden_states(
             end_pos = start_pos + num_computed_tokens
 
             # put received KV caches into paged memory
-            for i in range(model_executable.model.start_layer,
-                           model_executable.model.end_layer):
-
-                kv_cache = kv_caches[i - model_executable.model.start_layer]
-                layer = model_executable.model.layers[i]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    layer.self_attn.attn = layer.self_attn.mla_attn
-                    k_c_normed_k_pe = keys[
-                        i - model_executable.model.start_layer].to(
-                            kv_cache.device).squeeze(1)
-                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
-                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
-                    ops.concat_and_cache_mla(
-                        k_c_normed,
-                        k_pe,
-                        kv_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                    )
-                else:
-                    key_cache, value_cache = kv_cache[0], kv_cache[1]
-                    ops.reshape_and_cache_flash(
-                        keys[i - model_executable.model.start_layer].to(
-                            key_cache.device),
-                        values[i - model_executable.model.start_layer].to(
-                            value_cache.device),
-                        key_cache,
-                        value_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                        layer.self_attn.attn._v_scale,
-                    )
+            for cur_layer in range(start_layer, end_layer):
+
+                layer_id = cur_layer - start_layer
+                kv_cache = kv_caches[layer_id]
+                layer = model_executable.model.layers[cur_layer]
+
+                # get remote kvcache
+                remote_k, remote_v = keys[layer_id], values[layer_id]
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)
 
             hidden_or_intermediate_states_for_one_req.append(hidden)
 
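On the receive side, `put_kv_to_cache` absorbs the two caching paths deleted above: `ops.concat_and_cache_mla` for Deepseek MLA with forward absorb, `ops.reshape_and_cache_flash` otherwise. The sketch below reconstructs that behaviour from the removed lines; `keys`/`values` here are already the per-layer tensors (`remote_k`/`remote_v`), and the MLA flags are explicit parameters only because the real helper reads them from its own state.

```python
# Sketch only: standalone equivalent of kv_helper.put_kv_to_cache,
# reconstructed from the two branches removed in the hunk above.
from vllm import _custom_ops as ops


def put_kv_to_cache(model_executable, keys, values, layer, kv_cache,
                    slot_mapping, start_pos, end_pos,
                    is_deepseek_mla, use_mla_opt):
    model_config = model_executable.model.config

    if is_deepseek_mla and use_mla_opt:
        # MLA with forward absorb: split the latent vector into the
        # compressed KV part and the rotary part, then cache them together.
        layer.self_attn.attn = layer.self_attn.mla_attn
        k_c_normed_k_pe = keys.to(kv_cache.device).squeeze(1)
        k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
        k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe,
            kv_cache,
            slot_mapping[start_pos:end_pos],
            layer.self_attn.attn.kv_cache_dtype,
            layer.self_attn.attn._k_scale,
        )
    else:
        # Standard paged attention: write keys and values separately.
        key_cache, value_cache = kv_cache[0], kv_cache[1]
        ops.reshape_and_cache_flash(
            keys.to(key_cache.device),
            values.to(value_cache.device),
            key_cache,
            value_cache,
            slot_mapping[start_pos:end_pos],
            layer.self_attn.attn.kv_cache_dtype,
            layer.self_attn.attn._k_scale,
            layer.self_attn.attn._v_scale,
        )
```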