Merged
Changes from 5 commits
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -25,7 +25,8 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -144,7 +145,9 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
147 changes: 129 additions & 18 deletions vllm/v1/attention/backends/flashinfer.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional

import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -18,10 +18,11 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -223,21 +224,49 @@ def __post_init__(self):


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._decode_wrapper = None # Wrapper for decode (general shape)

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)

self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
# Preparing persistent buffers
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)

def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
@@ -259,20 +288,49 @@ def _get_prefill_wrapper(self):
self._get_workspace_buffer(), get_kv_cache_layout())
return self._prefill_wrapper

def _get_decode_wrapper(self):
if self._decode_wrapper is None:
def _get_decode_wrapper(self,
batch_size: int,
use_cudagraph: bool = False):
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(
batch_size, None)
else:
decode_wrapper = self._decode_wrapper

if decode_wrapper is None:
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(

if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
else:
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
paged_kv_last_page_len_buffer=paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

# save the decode wrapper
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
else:
self._decode_wrapper = decode_wrapper

return decode_wrapper

def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
@@ -350,16 +408,34 @@ def _plan(self, num_prefills: int, num_decodes: int,
)

if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
pure_decode = num_prefills == 0
# Padding may be required for cudagraph replay
use_cudagraph = (self.enable_cuda_graph and pure_decode and
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens_decode = (
self.vllm_config.pad_for_cudagraph(num_decodes))
Collaborator

nit: do we need a new name here? Can we just use num_decodes?

i.e.

num_decodes = (
                         self.vllm_config.pad_for_cudagraph(num_decodes))

so then we don't need the else.

Contributor Author

I keep these as two variables for filling the padding region. See the code below.

                    # Carefully fill the padding region with reasonable
                    # values on the CPU.
                    # Make sure paged_kv_indptr_cpu is not decreasing.
                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
                                             num_input_tokens].fill_(
                                                 attn_metadata.
                                                 paged_kv_indptr_cpu[-1])
                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
                    # This is because flashinfer treats 0 as a full page
                    # instead of an empty one.
                    self.paged_kv_last_page_len_cpu[
                        num_decodes:num_input_tokens].fill_(1)

else:
num_input_tokens_decode = num_decodes

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens_decode, use_cudagraph)
if not FlashInferBackend.use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
attn_metadata.head_dim):
# TODO: Override flashinfer's plan function to avoid some
# host-to-device copy overhead.
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:num_decodes],
# NOTE: Use the persistent buffers with the padded length, instead of
# the same-address but chunked-length buffers in the attn_metadata.
# This is for compatibility with FlashInfer's decode_wrapper when
# using cudagraph.
self.paged_kv_indptr[:num_input_tokens_decode + 1],
self.paged_kv_indices if use_cudagraph else \
attn_metadata.paged_kv_indices,
self.paged_kv_last_page_len[:num_input_tokens_decode],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -378,6 +454,7 @@ def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
@@ -421,17 +498,31 @@ def build(self,
device=block_table_tensor.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
num_actual_pages = paged_kv_indices.size(0)
self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
non_blocking=True)
self.paged_kv_indices[num_actual_pages:].fill_(-1)

paged_kv_indptr = torch.cat([
torch.zeros(1,
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
non_blocking=True)
# make sure self.paged_kv_indptr is not decreasing
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len,
non_blocking=True)
# Fill the remaining paged_kv_last_page_len with 1. This is because
# flashinfer treats 0 as a full page instead of an empty one.
self.paged_kv_last_page_len[num_reqs:].fill_(1)

cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -441,9 +532,9 @@ def build(self,
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr=qo_indptr,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs],
paged_kv_indices=self.paged_kv_indices[:num_actual_pages],
paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs],
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -471,6 +562,26 @@ def build(self,

return attn_metadata

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata

assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only
Contributor

Nit: You shouldn't need to set this. You can add it to your decode-only assert on the previous line.

Contributor Author

I think this is a common practice now (also see this part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.

return self.build(0, m)

def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1

def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
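
A minimal, self-contained sketch (not part of the PR) of the padding scheme the builder above uses for cudagraph replay: the decode batch is padded up to the nearest captured size, the padded tail of the indptr buffer repeats the last valid offset so it stays non-decreasing, and padded last-page lengths are set to 1 because FlashInfer treats 0 as a full page. The capture sizes are assumed, and pad_for_cudagraph here is a simplified stand-in for vllm_config.pad_for_cudagraph.

import bisect

import torch

capture_sizes = [1, 2, 4, 8]  # assumed cudagraph capture sizes
max_num_reqs = 8


def pad_for_cudagraph(num_decodes: int) -> int:
    # Round the decode batch size up to the nearest captured size.
    return capture_sizes[bisect.bisect_left(capture_sizes, num_decodes)]


# Persistent buffers, mirroring those allocated in
# FlashInferMetadataBuilder.__init__.
paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
paged_kv_last_page_len = torch.zeros(max_num_reqs, dtype=torch.int32)

num_decodes = 3
num_input_tokens_decode = pad_for_cudagraph(num_decodes)  # -> 4

# Pretend the first 1 + num_decodes entries already hold cumulative page
# counts copied from the block table.
paged_kv_indptr[:1 + num_decodes] = torch.tensor([0, 2, 5, 7],
                                                 dtype=torch.int32)

# Keep the padded tail of indptr non-decreasing by repeating the last
# valid offset ...
paged_kv_indptr[1 + num_decodes:1 + num_input_tokens_decode] = \
    paged_kv_indptr[num_decodes]
# ... and use 1 (not 0) for padded last-page lengths, since FlashInfer
# treats 0 as a full page rather than an empty one.
paged_kv_last_page_len[num_decodes:num_input_tokens_decode] = 1
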
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -18,6 +18,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -17,6 +17,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

# yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,8 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
18 changes: 17 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass
@@ -63,9 +64,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""

NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# The level of CUDA Graph support this backend/builder provides for attention.
full_cudagraph_supported: ClassVar[bool] = False
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
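
As a compact illustration of the contract the new enum establishes (not part of the PR; the dataclass and builder below are simplified stand-ins for vLLM's real types, and the batch sizes are assumed): a builder declares its support level as a ClassVar, a decode-only backend gates graph replay on max_query_len == 1, and the runner trims its capture sizes accordingly, as in the gpu_model_runner.py hunk that follows.

import enum
from dataclasses import dataclass
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


@dataclass
class CommonAttentionMetadata:
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


class MyDecodeOnlyBuilder:
    # The model runner reads this ClassVar to decide how to capture graphs.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY

    def can_run_in_cudagraph(self, m: CommonAttentionMetadata) -> bool:
        # Replay the captured graph only for pure-decode batches.
        return m.max_query_len == 1


# Runner-side checks, condensed from the gpu_model_runner.py change below.
builder = MyDecodeOnlyBuilder()
assert builder.can_run_in_cudagraph(
    CommonAttentionMetadata(num_reqs=4, num_actual_tokens=4, max_query_len=1))

if builder.attn_cudagraph_support == AttentionCGSupport.NEVER:
    raise ValueError("full_cuda_graph requires an attention backend "
                     "with cudagraph support")
if builder.attn_cudagraph_support == AttentionCGSupport.PURE_DECODE_ONLY:
    max_num_seqs = 128  # assumed scheduler limit
    cudagraph_batch_sizes = [1, 2, 4, 8, 16, 256, 512]
    cudagraph_batch_sizes = [
        size for size in cudagraph_batch_sizes if size <= max_num_seqs
    ]
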
22 changes: 15 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -47,7 +47,7 @@
is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -2437,12 +2437,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
self.device,
)

if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]

self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)