28 changes: 28 additions & 0 deletions vllm/v1/worker/gpu/attn_utils.py
@@ -12,6 +12,7 @@
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
@@ -143,6 +144,28 @@ def build_slot_mappings_by_layer(
return slot_mappings_by_layer


def prepare_dcp_local_seq_lens(
dcp_local_seq_lens: torch.Tensor,
seq_lens: torch.Tensor,
num_reqs: int,
dcp_size: int,
dcp_rank: int,
cp_kv_cache_interleave_size: int,
) -> None:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if dcp_size <= 1:
return

local_seq_lens = get_dcp_local_seq_lens(
seq_lens[:num_reqs],
dcp_size=dcp_size,
dcp_rank=dcp_rank,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
dcp_local_seq_lens[num_reqs:].zero_()


def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
@@ -155,9 +178,13 @@ def build_attn_metadata(
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]

if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]

attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
@@ -175,6 +202,7 @@
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
)

attn_metadata_builder = attn_metadata_builders[i]
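Note on the new helper: `prepare_dcp_local_seq_lens` delegates the actual split to `get_dcp_local_seq_lens`, which this diff does not show. Below is a pure-Python sketch of what a per-rank local sequence length works out to under the interleaved sharding rule used by the slot-mapping kernel later in this PR (a token at position `p` lives on rank `(p // interleave_size) % world_size`). The function name `local_seq_len` and the constants are illustrative only, not the vLLM implementation:

```python
# Illustrative sketch (not the vLLM helper): derive a rank's local sequence
# length from the global seq_len under interleaved context-parallel sharding.
def local_seq_len(seq_len: int, world_size: int, rank: int,
                  interleave_size: int = 1) -> int:
    # Full cycles of (world_size * interleave_size) tokens give every rank
    # exactly interleave_size tokens per cycle.
    full_cycles, rem = divmod(seq_len, world_size * interleave_size)
    local = full_cycles * interleave_size
    # Leftover tokens fill ranks 0, 1, ... in chunks of interleave_size.
    start = rank * interleave_size
    local += min(max(rem - start, 0), interleave_size)
    return local

# Example: 10 tokens, 2 ranks, interleave_size=2
# rank 0 holds positions {0,1,4,5,8,9}; rank 1 holds {2,3,6,7}.
assert local_seq_len(10, world_size=2, rank=0, interleave_size=2) == 6
assert local_seq_len(10, world_size=2, rank=1, interleave_size=2) == 4
```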
46 changes: 43 additions & 3 deletions vllm/v1/worker/gpu/block_table.py
@@ -4,6 +4,7 @@

import torch

from vllm.distributed import get_dcp_group
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -18,19 +19,36 @@ def __init__(
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
cp_kv_cache_interleave_size: int = 1,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
assert cp_kv_cache_interleave_size >= 1
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size

try:
dcp = get_dcp_group()
self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
except AssertionError:
self.dcp_world_size, self.dcp_rank = 1, 0
# TODO(wentao): PCP support
self.total_cp_world_size = self.dcp_world_size
self.total_cp_rank = self.dcp_rank

self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
# with DCP, a request's KV is sharded across
# ranks, so one physical block on this rank
# corresponds to `block_size * total_cp_world_size`
# tokens in the global (unsharded) sequence.
virtual_block_size = block_size * self.total_cp_world_size
max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
@@ -131,6 +149,9 @@ def compute_slot_mappings(
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
TOTAL_CP_RANK=self.total_cp_rank,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore
)
@@ -183,6 +204,9 @@ def _compute_slot_mappings_kernel(
block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
TOTAL_CP_WORLD_SIZE: tl.constexpr,
TOTAL_CP_RANK: tl.constexpr,
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
PAD_ID: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
):
@@ -201,18 +225,34 @@
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
block_size = tl.load(block_sizes + group_id)
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE

req_state_idx = tl.load(idx_mapping + batch_idx)
start_idx = tl.load(query_start_loc + batch_idx)
end_idx = tl.load(query_start_loc + batch_idx + 1)
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // block_size
block_indices = positions // virtual_block_size
block_numbers = tl.load(
block_table_ptr + req_state_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * block_size + positions % block_size
virtual_block_offsets = positions - block_indices * virtual_block_size

# determine whether the token is stored on this CP rank.
is_local = (
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
# mapping virtual block offsets to local block offsets.
local_block_offsets = (
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
) * CP_KV_CACHE_INTERLEAVE_SIZE + (
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
)

# physical slot index
slot_ids = block_numbers * block_size + local_block_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


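For reference, the interleaved slot-mapping arithmetic added to `_compute_slot_mappings_kernel` can be sanity-checked with a small host-side mirror. This is a sketch of the same math only, not production code; `PAD_SLOT_ID` is stubbed as -1 here for illustration, while the real constant comes from `vllm.v1.attention.backends.utils`:

```python
# Pure-Python mirror of the kernel's per-token slot computation.
PAD_SLOT_ID = -1  # stub; the real value is imported in block_table.py

def slot_id(pos: int, block_number: int, block_size: int,
            world_size: int, rank: int, interleave: int) -> int:
    virtual_block_size = block_size * world_size
    virtual_offset = pos % virtual_block_size
    # The token is stored on this rank only if the interleave pattern says so.
    if (virtual_offset // interleave) % world_size != rank:
        return PAD_SLOT_ID
    # Collapse the virtual offset to a local offset within the physical block.
    local_offset = (
        virtual_offset // (world_size * interleave)
    ) * interleave + virtual_offset % interleave
    return block_number * block_size + local_offset

# Example: block_size=4, world_size=2, interleave=2 -> virtual block of 8 tokens.
# Positions 0-1 and 4-5 land on rank 0; positions 2-3 and 6-7 land on rank 1.
assert [slot_id(p, 0, 4, 2, 0, 2) for p in range(8)] == [0, 1, -1, -1, 2, 3, -1, -1]
assert [slot_id(p, 0, 4, 2, 1, 2) for p in range(8)] == [-1, -1, 0, 1, -1, -1, 2, 3]
```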
20 changes: 20 additions & 0 deletions vllm/v1/worker/gpu/cudagraph_utils.py
@@ -10,13 +10,15 @@

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed import get_dcp_group
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
prepare_dcp_local_seq_lens,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
@@ -257,6 +259,23 @@ def prepare_inputs_to_capture(
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0

try:
dcp_group = get_dcp_group()
dcp_world_size = dcp_group.world_size
dcp_rank = dcp_group.rank_in_group
except AssertionError:
dcp_world_size = 1
dcp_rank = 0
if dcp_world_size > 1:
prepare_dcp_local_seq_lens(
input_buffers.dcp_local_seq_lens,
input_buffers.seq_lens,
num_reqs,
dcp_size=dcp_world_size,
dcp_rank=dcp_rank,
cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
)

input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
slot_mappings_by_layer = build_slot_mappings_by_layer(
@@ -275,5 +294,6 @@
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
)
return attn_metadata, slot_mappings_by_layer
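The capture path writes into `input_buffers.dcp_local_seq_lens` in place rather than allocating a fresh tensor, so the address recorded by the CUDA graph stays valid across replays and stale entries left over from a previously larger batch are cleared. A minimal sketch of that persistent-buffer pattern (buffer size and names are illustrative, not the vLLM API):

```python
import torch

# Persistent buffer sized for the worst case; a CUDA graph captures its address.
MAX_NUM_REQS = 8  # illustrative capacity
dcp_local_seq_lens = torch.zeros(MAX_NUM_REQS, dtype=torch.int32)

def refresh(local_lens: torch.Tensor) -> None:
    """Overwrite the active prefix and clear the stale tail, all in place."""
    num_reqs = local_lens.numel()
    dcp_local_seq_lens[:num_reqs].copy_(local_lens)
    dcp_local_seq_lens[num_reqs:].zero_()  # drop leftovers from a larger batch

refresh(torch.tensor([5, 3, 7], dtype=torch.int32))
refresh(torch.tensor([4], dtype=torch.int32))
assert dcp_local_seq_lens.tolist() == [4, 0, 0, 0, 0, 0, 0, 0]
```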
4 changes: 4 additions & 0 deletions vllm/v1/worker/gpu/input_batch.py
@@ -27,6 +27,10 @@ def __init__(
max_num_reqs + 1, dtype=torch.int32, device=device
)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
# DCP: per-request local seq_lens buffer
self.dcp_local_seq_lens = torch.zeros(
max_num_reqs, dtype=torch.int32, device=device
)


@dataclass
22 changes: 22 additions & 0 deletions vllm/v1/worker/gpu/model_runner.py
@@ -11,6 +11,7 @@
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group,
prepare_communication_buffer_for_model,
)
@@ -24,13 +25,15 @@
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
get_kv_cache_spec,
init_attn_backend,
init_kv_cache,
prepare_dcp_local_seq_lens,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
@@ -248,11 +251,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
cp_kv_cache_interleave_size=(
self.parallel_config.cp_kv_cache_interleave_size
),
)

self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
check_attention_cp_compatibility(self.vllm_config)
if self.do_spec_decode:
# HACK(woosuk)
self.speculator.set_attn(
@@ -294,6 +301,7 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
)
input_batch.attn_metadata = attn_metadata
input_batch.slot_mappings = slot_mappings_by_layer
@@ -627,6 +635,19 @@ def prepare_inputs(
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]

dcp_size = self.parallel_config.decode_context_parallel_size
if dcp_size > 1:
prepare_dcp_local_seq_lens(
self.input_buffers.dcp_local_seq_lens,
seq_lens,
num_reqs,
dcp_size=dcp_size,
dcp_rank=get_dcp_group().rank_in_group,
cp_kv_cache_interleave_size=(
self.parallel_config.cp_kv_cache_interleave_size
),
)

# Prepare M-RoPE positions.
if self.uses_mrope:
self.mrope_states.prepare_mrope_positions(
@@ -674,6 +695,7 @@ def prepare_inputs(
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
)

input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]