diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
index c5135dab23eb..7b26aec23239 100644
--- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """
 MooncakeStore Connector for Distributed Machine Learning Inference
-
 The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 (KV cache producer) and decode vLLM workers (KV cache consumer) using a
 database-style KVStore.
@@ -11,9 +10,10 @@
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 
@@ -32,8 +32,7 @@ def __init__(
         config: VllmConfig,
     ):
         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-
+        self.kv_helper = kv_helper(config)
         self.local_tp_rank = local_rank
 
         # Init kv_store
@@ -80,12 +79,7 @@ def send_kv_caches_and_hidden_states(
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)
 
         for idx, slen in enumerate(seq_lens):
             start_pos = sum(seq_lens[:idx])
@@ -97,10 +91,8 @@
 
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
-
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)
                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
                 keys.append(key_cache[current_slot_mapping].unsqueeze(0))
@@ -173,22 +165,15 @@ def recv_kv_caches_and_hidden_states(
                 layer = model_executable.model.layers[layer_id]
                 # get kvcache object
                 kv_cache = kv_caches[layer_id - start_layer]
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
 
-                # get remote kvcache
+                # get remote kvcache
                 remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
                     layer_id]
-                # use ops.reshape_and_cache_flash to put kv into kvcache
-                ops.reshape_and_cache_flash(
-                    remote_k.to(key_cache.device),
-                    remote_v.to(value_cache.device),
-                    key_cache,
-                    value_cache,
-                    slot_mapping[start_pos:end_pos],
-                    layer.self_attn.attn.kv_cache_dtype,
-                    layer.self_attn.attn._k_scale,
-                    layer.self_attn.attn._v_scale,
-                )
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)
 
             hidden_or_intermediate_states_for_one_req.append(hidden)
 
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 49b97d7b5889..0464a7585138 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -12,10 +12,10 @@
 
 import torch
 
-import vllm.envs as envs
-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
     SimpleBuffer)
 from vllm.logger import init_logger
@@ -37,9 +37,7 @@ def __init__(
     ):
         self.config = config.kv_transfer_config
 
-        self.tp_size = config.parallel_config.tensor_parallel_size
-        self.is_deepseek_mla = config.model_config.is_deepseek_mla
-        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.kv_helper = kv_helper(config)
 
         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -165,31 +163,7 @@ def send_kv_caches_and_hidden_states(
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-
-        # Deepseek's MLA (Multi-head Latent Attention) uses two different
-        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
-        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
-        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
-        # kv_lora_rank + qk_rope_head_dim].
-        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
-        # to a kv_cache shape of [2, num_blks, blk_size,
-        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
-        # For more details, see vllm/attention/backends/mla/common.py.
-        if self.is_deepseek_mla and self.use_mla_opt:
-            head_size = model_config.kv_lora_rank + \
-                model_config.qk_rope_head_dim
-            num_heads = 1
-        elif self.is_deepseek_mla and not self.use_mla_opt:
-            head_size = model_config.qk_nope_head_dim + \
-                model_config.qk_rope_head_dim
-        else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
@@ -212,13 +186,8 @@
 
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
-                else:
-                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)
 
                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
@@ -248,12 +217,12 @@
         # and hidden states.
         bypass_model_exec = True
 
-        model_config = model_executable.model.config
-
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer
 
         hidden_or_intermediate_states_for_one_req = []
 
@@ -312,41 +281,19 @@
             end_pos = start_pos + num_computed_tokens
 
             # put received KV caches into paged memory
-            for i in range(model_executable.model.start_layer,
-                           model_executable.model.end_layer):
-
-                kv_cache = kv_caches[i - model_executable.model.start_layer]
-                layer = model_executable.model.layers[i]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    layer.self_attn.attn = layer.self_attn.mla_attn
-                    k_c_normed_k_pe = keys[
-                        i - model_executable.model.start_layer].to(
-                            kv_cache.device).squeeze(1)
-                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
-                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
-                    ops.concat_and_cache_mla(
-                        k_c_normed,
-                        k_pe,
-                        kv_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                    )
-                else:
-                    key_cache, value_cache = kv_cache[0], kv_cache[1]
-                    ops.reshape_and_cache_flash(
-                        keys[i - model_executable.model.start_layer].to(
-                            key_cache.device),
-                        values[i - model_executable.model.start_layer].to(
-                            value_cache.device),
-                        key_cache,
-                        value_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                        layer.self_attn.attn._v_scale,
-                    )
+            for cur_layer in range(start_layer, end_layer):
+
+                layer_id = cur_layer - start_layer
+                kv_cache = kv_caches[layer_id]
+                layer = model_executable.model.layers[cur_layer]
+
+                # get remote kvcache
+                remote_k, remote_v = keys[layer_id], values[layer_id]
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)
 
             hidden_or_intermediate_states_for_one_req.append(hidden)
 
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
new file mode 100644
index 000000000000..0b0ce9828a74
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+KV cache helper for store.
+"""
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class model_aware_kv_ops_helper:
+
+    def __init__(self, config: VllmConfig):
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.tp_size = config.parallel_config.tensor_parallel_size
+
+    def get_model_args(self, model_executable: torch.nn.Module):
+
+        model_config = model_executable.model.config
+        self.model_executable = model_executable
+        num_heads = int(model_config.num_key_value_heads / self.tp_size)
+        hidden_size = model_config.hidden_size
+        num_attention_heads = model_config.num_attention_heads
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
+
+        return num_heads, head_size
+
+    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
+        if self.is_deepseek_mla and self.use_mla_opt:
+            key_cache = kv_cache.reshape(-1, num_heads, head_size)
+            value_cache = kv_cache.reshape(-1, num_heads, head_size)
+        else:
+            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+        return key_cache, value_cache
+
+    def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
+                        layer, kv_cache, slot_mapping, start_pos, end_pos):
+
+        model_config = model_executable.model.config
+
+        if self.is_deepseek_mla and self.use_mla_opt:
+            layer.self_attn.attn = layer.self_attn.mla_attn
+            k_c_normed_k_pe = keys.squeeze(1)
+            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+            ops.concat_and_cache_mla(
+                k_c_normed.to(kv_cache.device),
+                k_pe.to(kv_cache.device),
+                kv_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+            )
+        else:
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            ops.reshape_and_cache_flash(
+                keys.to(key_cache.device),
+                values.to(value_cache.device),
+                key_cache,
+                value_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+                layer.self_attn.attn._v_scale,
+            )
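
Usage sketch (commentary, not part of the patch): the new model_aware_kv_ops_helper centralizes the DeepSeek-MLA- and tensor-parallel-aware KV-cache logic that SimpleConnector and MooncakeStoreConnector previously duplicated. The hypothetical HypotheticalConnector below only illustrates the intended call pattern on the send and receive paths; model_executable, kv_caches, keys, values, slot_mapping, start_pos, and end_pos stand in for objects the surrounding worker code provides and are not defined by this patch.

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


class HypotheticalConnector:
    """Illustrative only; real connectors subclass KVConnectorBase."""

    def __init__(self, config: VllmConfig):
        # One helper per connector; it reads the MLA and TP settings from
        # the vLLM config once, at construction time.
        self.kv_helper = kv_helper(config)

    def send(self, model_executable, kv_caches, slot_mapping_flat, start_pos,
             end_pos):
        # Resolve the per-model KV shape (handles both DeepSeek MLA layouts).
        num_heads, head_size = self.kv_helper.get_model_args(model_executable)
        for kv_cache in kv_caches:
            # View the paged cache as flat [num_tokens, num_heads, head_size]
            # key/value tensors, regardless of the underlying layout.
            key_cache, value_cache = self.kv_helper.get_kv_from_cache(
                kv_cache, num_heads, head_size)
            # ... index key_cache/value_cache with
            # slot_mapping_flat[start_pos:end_pos] and ship them out.

    def recv(self, model_executable, kv_caches, keys, values, slot_mapping,
             start_pos, end_pos):
        for layer_id, kv_cache in enumerate(kv_caches):
            layer = model_executable.model.layers[layer_id]
            # Write the received KV back into paged memory; the helper picks
            # concat_and_cache_mla or reshape_and_cache_flash internally.
            self.kv_helper.put_kv_to_cache(model_executable, keys[layer_id],
                                           values[layer_id], layer, kv_cache,
                                           slot_mapping, start_pos, end_pos)

In the real connectors the layer loop runs from model_executable.model.start_layer to end_layer, as in the hunks above; the sketch assumes a single pipeline stage for brevity.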