148 changes: 148 additions & 0 deletions vllm/config/compilation.py
@@ -26,6 +26,8 @@

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionCGSupport
from vllm.v1.kv_cache_interface import KVCacheConfig
else:
VllmConfig = object

@@ -1241,6 +1243,152 @@ def is_custom_op_enabled(self, op: str) -> bool:
assert "none" in self.custom_ops
return f"+{op}" in self.custom_ops

def resolve_cudagraph_mode_and_sizes(
self,
min_cg_support: "AttentionCGSupport",
min_cg_attn_backend: str | None,
uniform_decode_query_len: int = 1,
tensor_parallel_size: int = 1,
kv_cache_config: "KVCacheConfig | None" = None,
max_num_reqs: int | None = None,
is_profiling: bool = False,
) -> CUDAGraphMode:
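        # Deferred import: AttentionCGSupport is only imported under
        # TYPE_CHECKING at module level, presumably to avoid a circular import.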
from vllm.v1.attention.backend import AttentionCGSupport

cudagraph_mode = self.cudagraph_mode
if cudagraph_mode is None or cudagraph_mode == CUDAGraphMode.NONE:
self.cudagraph_mode = CUDAGraphMode.NONE
return CUDAGraphMode.NONE

        # Check that full cudagraphs for mixed batches are supported
if (
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
and min_cg_support != AttentionCGSupport.ALWAYS
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_attn_backend} backend (support: "
f"{min_cg_support})"
)
if min_cg_support == AttentionCGSupport.NEVER:
                # if full cudagraphs are not supported at all, raise immediately
msg += (
"; please try cudagraph_mode=PIECEWISE, and "
"make sure compilation mode is VLLM_COMPILE"
)
raise ValueError(msg)

            # otherwise, fall back to a supported full-cudagraph mode
if self.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg)

        # check that decode-only full cudagraphs are supported, if requested
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_attn_backend} backend (support: "
f"{min_cg_support})"
)
if self.mode == CompilationMode.VLLM_COMPILE and (
self.splitting_ops_contain_attention()
or self.use_inductor_graph_partition
):
msg += (
"; setting cudagraph_mode=PIECEWISE because "
"attention is compiled piecewise"
)
cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
msg += (
"; setting cudagraph_mode=NONE because "
"attention is not compiled piecewise"
)
cudagraph_mode = CUDAGraphMode.NONE
logger.warning(msg)

        # check that full-cudagraph decode is supported when combined with
        # spec-decode (uniform_decode_query_len > 1)
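        # AttentionCGSupport values are assumed to be ordered by capability
        # (NEVER lowest, ALWAYS highest), so this .value comparison asks
        # whether the backend supports at least uniform-batch capture.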
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and uniform_decode_query_len > 1
and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_attn_backend} (support: {min_cg_support})"
)
if self.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
cudagraph_mode = CUDAGraphMode.NONE
logger.warning(msg)

        # double-check that full cudagraphs are still supported if requested,
        # even after the automatic downgrades above
if (
cudagraph_mode.has_full_cudagraphs()
and min_cg_support == AttentionCGSupport.NEVER
):
raise ValueError(
f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_attn_backend} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation mode is VLLM_COMPILE"
)

        # Adjust cudagraph sizes to be a multiple of uniform_decode_query_len
        # to avoid https://github.com/vllm-project/vllm/issues/28207 (temporary
        # fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536).
        # This will be removed in the near future, once there are separate
        # cudagraph capture sizes for decode and mixed prefill-decode batches.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and uniform_decode_query_len > 1
):
self.adjust_cudagraph_sizes_for_spec_decode(
uniform_decode_query_len,
tensor_parallel_size,
)

# For Mamba models with FULL decode cudagraphs, each decode
# sequence needs one Mamba cache block. The decode cudagraph
# dispatcher already caps batch sizes at max_num_seqs, so we just
# need to verify that enough blocks exist. Raising here instead
# of silently capping cudagraph_capture_sizes avoids unintended
# restrictions on PIECEWISE (prefill) cudagraphs.
# See: https://github.com/vllm-project/vllm/issues/34094
if (
kv_cache_config is not None
and max_num_reqs is not None
and cudagraph_mode.has_full_cudagraphs()
and not is_profiling
and kv_cache_config.has_mamba_layers
and max_num_reqs > kv_cache_config.num_blocks
):
raise ValueError(
f"max_num_seqs ({max_num_reqs}) exceeds available Mamba cache "
f"blocks ({kv_cache_config.num_blocks}). Each decode sequence "
"requires one Mamba cache block, so CUDA graph capture cannot "
"proceed. Please lower max_num_seqs to at most "
f"{kv_cache_config.num_blocks} or increase "
"gpu_memory_utilization."
)

self.cudagraph_mode = cudagraph_mode
return cudagraph_mode

def adjust_cudagraph_sizes_for_spec_decode(
self, uniform_decode_query_len: int, tensor_parallel_size: int
):
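For illustration, a minimal sketch of the downgrade behavior (not part of the PR; assumes CompilationConfig can be constructed standalone with just cudagraph_mode):

from vllm.config.compilation import CompilationConfig, CUDAGraphMode
from vllm.v1.attention.backend import AttentionCGSupport

# Request FULL cudagraphs against a backend whose support is weaker
# than ALWAYS (here: uniform-batch capture only).
cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL)
resolved = cfg.resolve_cudagraph_mode_and_sizes(
    min_cg_support=AttentionCGSupport.UNIFORM_BATCH,
    min_cg_attn_backend="FlashAttentionBackend",
)
# Mixed-batch FULL capture requires AttentionCGSupport.ALWAYS, so the
# mode is downgraded with a warning: FULL_AND_PIECEWISE when attention
# is in the splitting ops, FULL_DECODE_ONLY otherwise.
assert resolved in (
    CUDAGraphMode.FULL_AND_PIECEWISE,
    CUDAGraphMode.FULL_DECODE_ONLY,
)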
44 changes: 39 additions & 5 deletions vllm/v1/worker/gpu/attn_utils.py
@@ -1,14 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, cast

import numpy as np
import torch

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
+from vllm.v1.attention.backend import (
+    AttentionBackend,
+    AttentionCGSupport,
+    CommonAttentionMetadata,
+)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
@@ -18,6 +23,12 @@
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache


@dataclass(frozen=True)
class AttentionCGSupportInfo:
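    """Weakest cudagraph support observed across the attention backends in
    use, plus the name of the backend that imposed it (None if no backend
    restricted support below ALWAYS)."""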
min_cg_support: AttentionCGSupport = AttentionCGSupport.ALWAYS
min_cg_attn_backend: str | None = None


def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase)
@@ -34,10 +45,17 @@ def init_attn_backend(
vllm_config: VllmConfig,
device: torch.device,
active_layer_names: set[str] | None = None,
-):
+) -> tuple[
+    dict[str, type[AttentionBackend]],
+    list[list[AttentionGroup]],
+    AttentionCGSupportInfo,
+]:
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = []
attn_backend_workspace: torch.Tensor | None = None
# Find minimum cudagraph support across all attention backends
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_attn_backend = None
for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups
):
@@ -86,8 +104,24 @@ def init_attn_backend(
else:
if hasattr(builder, "set_workspace_buffer"):
builder.set_workspace_buffer(attn_backend_workspace)
# Check cudagraph support for the attention backend
cg_support = builder.get_cudagraph_support(
vllm_config,
cast(AttentionSpec, kv_cache_group_spec.kv_cache_spec),
)
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_attn_backend = attn_backend.__name__
attn_groups.append(groups)
-    return attn_backends, attn_groups
+    return (
+        attn_backends,
+        attn_groups,
+        AttentionCGSupportInfo(
+            min_cg_support=min_cg_support,
+            min_cg_attn_backend=min_cg_attn_backend,
+        ),
+    )


def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
@@ -110,7 +144,7 @@ def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
-    attn_backends: dict[str, AttentionBackend],
+    attn_backends: dict[str, type[AttentionBackend]],
cache_dtype: str,
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
@@ -158,7 +192,7 @@ def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
-    attn_backends: dict[str, AttentionBackend],
+    attn_backends: dict[str, type[AttentionBackend]],
device: torch.device,
cache_dtype: str,
) -> dict[str, torch.Tensor]:
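A sketch of how the two pieces compose (an illustrative call site, not the PR's actual worker code; the positional argument order of init_attn_backend and the config attribute names here are assumptions):

attn_backends, attn_groups, cg_info = init_attn_backend(
    kv_cache_config, vllm_config, device
)
# Feed the weakest observed backend support into the mode resolution
# added in vllm/config/compilation.py above.
resolved_mode = vllm_config.compilation_config.resolve_cudagraph_mode_and_sizes(
    min_cg_support=cg_info.min_cg_support,
    min_cg_attn_backend=cg_info.min_cg_attn_backend,
    kv_cache_config=kv_cache_config,
    max_num_reqs=vllm_config.scheduler_config.max_num_seqs,
)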
44 changes: 43 additions & 1 deletion vllm/v1/worker/gpu/dp_utils.py
@@ -20,7 +20,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None


def sync_cudagraph_and_dp_padding(
-    cudagraph_manager: CudaGraphManager,
+    cudagraph_manager: CudaGraphManager | None,
desired_batch_desc: BatchExecutionDescriptor,
num_tokens: int,
num_reqs: int,
@@ -61,6 +61,10 @@
num_reqs=num_reqs,
), num_tokens_across_dp

assert cudagraph_manager is not None, (
"cudagraph_manager should only be None during profile run, "
"where synced_cg_mode must be NONE across all DP ranks"
)
synced_num_tokens = int(num_tokens_across_dp.max().item())
synced_uniform_token_count = uniform_token_counts_across_dp[0]
    # If ranks disagree on the uniform token count, or it is 0 (meaning None), set it to None
@@ -79,3 +83,41 @@
num_tokens_across_dp[:] = synced_desc.num_tokens

return synced_desc, num_tokens_across_dp


def dispatch_cg_and_sync_dp(
cudagraph_manager: CudaGraphManager | None,
num_reqs: int,
num_tokens: int,
uniform_token_count: int | None,
dp_size: int,
dp_rank: int,
need_eager: bool = False,
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
if need_eager:
batch_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_tokens,
num_reqs=num_reqs,
)
else:
assert cudagraph_manager is not None, (
"cudagraph_manager should only be None during profile run, "
"where need_eager must be True"
)
batch_desc = cudagraph_manager.dispatch(
num_reqs, num_tokens, uniform_token_count
)

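    # Single DP rank: nothing to synchronize, so no padding tensor is needed.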
if dp_size == 1:
return batch_desc, None

return sync_cudagraph_and_dp_padding(
cudagraph_manager,
batch_desc,
num_tokens,
num_reqs,
uniform_token_count,
dp_size,
dp_rank,
)
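
A hedged usage sketch (variable names are illustrative, not from the PR): during the profile run no cudagraph manager exists yet, so callers must pass need_eager=True; in normal execution the manager dispatches a batch descriptor and, when dp_size > 1, the result is synchronized across DP ranks.

# Profile run: no manager yet, so eager execution is forced.
batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
    cudagraph_manager=None,
    num_reqs=num_reqs,
    num_tokens=num_tokens,
    uniform_token_count=None,
    dp_size=dp_size,
    dp_rank=dp_rank,
    need_eager=True,
)

# Normal run: the manager picks a cudagraph (or eager) descriptor,
# then DP ranks agree on a common padded token count.
batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
    cudagraph_manager, num_reqs, num_tokens, uniform_token_count,
    dp_size, dp_rank,
)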