@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
MooncakeStore Connector for Distributed Machine Learning Inference

The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
@@ -11,9 +10,10 @@

import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

@@ -32,8 +32,7 @@ def __init__(
config: VllmConfig,
):
self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size

self.kv_helper = kv_helper(config)
self.local_tp_rank = local_rank

# Init kv_store
@@ -80,12 +79,7 @@ def send_kv_caches_and_hidden_states(
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer

model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
num_heads, head_size = self.kv_helper.get_model_args(model_executable)

for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
@@ -97,10 +91,8 @@ def send_kv_caches_and_hidden_states(

for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]

key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

key_cache, value_cache = self.kv_helper.get_kv_from_cache(
kv_cache, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

keys.append(key_cache[current_slot_mapping].unsqueeze(0))
@@ -173,22 +165,15 @@ def recv_kv_caches_and_hidden_states(
layer = model_executable.model.layers[layer_id]
# get kvcache object
kv_cache = kv_caches[layer_id - start_layer]
key_cache, value_cache = kv_cache[0], kv_cache[1]
# get remote kvcache

# get remote kvcache
remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
layer_id]
# use ops.reshape_and_cache_flash to put kv into kvcache
ops.reshape_and_cache_flash(
remote_k.to(key_cache.device),
remote_v.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)

self.kv_helper.put_kv_to_cache(model_executable, remote_k,
remote_v, layer, kv_cache,
slot_mapping, start_pos,
end_pos)

hidden_or_intermediate_states_for_one_req.append(hidden)

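For orientation, the sketch below shows one way the connector described in the docstring above might be enabled on the prefill (producer) side of a disaggregated deployment. It is an illustration only, not part of this diff: the connector name, the kv_role strings, and the direct KVTransferConfig construction are assumptions drawn from the surrounding code and from vLLM's disaggregated-prefill examples, and the Mooncake store itself needs its own configuration (endpoint, metadata server, etc.) that is omitted here.

# Minimal sketch (assumed usage, not part of this PR): launch a prefill
# worker that publishes KV caches through the MooncakeStoreConnector.
# The decode worker would be launched the same way with kv_role="kv_consumer".
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

ktc = KVTransferConfig(
    kv_connector="MooncakeStoreConnector",  # name assumed to match the connector registry
    kv_role="kv_producer",
)

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_transfer_config=ktc)
llm.generate(["San Francisco is a"], SamplingParams(temperature=0, max_tokens=1))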
95 changes: 21 additions & 74 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -12,10 +12,10 @@

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
@@ -37,9 +37,7 @@ def __init__(
):

self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.kv_helper = kv_helper(config)

if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -165,31 +163,7 @@ def send_kv_caches_and_hidden_states(
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer

model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads

# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
num_heads, head_size = self.kv_helper.get_model_args(model_executable)

# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
@@ -212,13 +186,8 @@

for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]

if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
kv_cache, num_heads, head_size)

current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

@@ -248,12 +217,12 @@ def recv_kv_caches_and_hidden_states(
# and hidden states.
bypass_model_exec = True

model_config = model_executable.model.config

input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer

hidden_or_intermediate_states_for_one_req = []

@@ -312,41 +281,19 @@ def recv_kv_caches_and_hidden_states(
end_pos = start_pos + num_computed_tokens

# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):

kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]

if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys[
i - model_executable.model.start_layer].to(
kv_cache.device).squeeze(1)
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
ops.concat_and_cache_mla(
k_c_normed,
k_pe,
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
for cur_layer in range(start_layer, end_layer):

layer_id = cur_layer - start_layer
kv_cache = kv_caches[layer_id]
layer = model_executable.model.layers[cur_layer]

# get remote kvcache
remote_k, remote_v = keys[layer_id], values[layer_id]

self.kv_helper.put_kv_to_cache(model_executable, remote_k,
remote_v, layer, kv_cache,
slot_mapping, start_pos,
end_pos)

hidden_or_intermediate_states_for_one_req.append(hidden)

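As a quick sanity check on the comment this diff moves into utils.py (the two Deepseek MLA kv_cache layouts), the lines below work out the head sizes with the published DeepSeek-V2/V3 dimensions. The numbers are used purely for illustration and are not asserted anywhere in this change.

# Worked numbers for the two MLA kv_cache layouts (published DeepSeek-V2/V3
# values, shown only for illustration):
kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim = 512, 64, 128

# VLLM_MLA_DISABLE=0 (forward absorb): cache shape is
# [num_blks, blk_size, 1, kv_lora_rank + qk_rope_head_dim]
head_size_mla_opt = kv_lora_rank + qk_rope_head_dim       # 576, with num_heads = 1

# VLLM_MLA_DISABLE=1 (standard attention): cache shape is
# [2, num_blks, blk_size, num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]
head_size_standard = qk_nope_head_dim + qk_rope_head_dim  # 192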
90 changes: 90 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
"""
KV cache helper for store.
"""
import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


class model_aware_kv_ops_helper:

def __init__(self, config: VllmConfig):
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.tp_size = config.parallel_config.tensor_parallel_size

def get_model_args(self, model_executable: torch.nn.Module):

model_config = model_executable.model.config
self.model_executable = model_executable
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads

# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))

return num_heads, head_size

def get_kv_from_cache(self, kv_cache, num_heads, head_size):
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
return key_cache, value_cache

def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
layer, kv_cache, slot_mapping, start_pos, end_pos):

model_config = model_executable.model.config

if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys.squeeze(1)
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
ops.concat_and_cache_mla(
k_c_normed.to(kv_cache.device),
k_pe.to(kv_cache.device),
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys.to(key_cache.device),
values.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
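Finally, a minimal sketch of the call pattern both connectors above now share, to make the helper's division of labour explicit. It is not part of the diff; all arguments are stand-ins for objects a connector already holds during send/recv, and only the three helper methods defined in utils.py are assumed.

# Illustration only: how a connector composes the helper during send/recv.
import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


def send_one_layer(config: VllmConfig, model_executable: torch.nn.Module,
                   kv_cache, slot_mapping, start_pos, end_pos):
    """Producer side: pull K/V for one layer in a model-aware layout."""
    helper = kv_helper(config)
    # Returns num_heads=1 and head_size=kv_lora_rank+qk_rope_head_dim when the
    # DeepSeek MLA optimization is active, the usual per-TP-rank split otherwise.
    num_heads, head_size = helper.get_model_args(model_executable)
    key_cache, value_cache = helper.get_kv_from_cache(kv_cache, num_heads,
                                                      head_size)
    keys = key_cache[slot_mapping[start_pos:end_pos]].unsqueeze(0)
    values = value_cache[slot_mapping[start_pos:end_pos]].unsqueeze(0)
    return keys, values


def recv_one_layer(config: VllmConfig, model_executable: torch.nn.Module,
                   remote_k, remote_v, layer, kv_cache, slot_mapping,
                   start_pos, end_pos):
    """Consumer side: write received K/V back into the paged cache."""
    helper = kv_helper(config)
    # Dispatches internally to ops.concat_and_cache_mla (MLA opt) or
    # ops.reshape_and_cache_flash (all other models).
    helper.put_kv_to_cache(model_executable, remote_k, remote_v, layer,
                           kv_cache, slot_mapping, start_pos, end_pos)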