Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
22 changes: 13 additions & 9 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -132,36 +133,39 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):

# Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform=False,
)
# FULL mode uses exact keys with num_reqs set
desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
# PIECEWISE mode uses relaxed keys with num_reqs=None
desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
assert key == desc_full_with_reqs
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
assert key == desc_piecewise
else:
assert rt_mode == CUDAGraphMode.NONE

# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False
)
if cudagraph_mode_str == "FULL":
# Pure FULL mode uses non-uniform keys for all batches
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
assert key == desc_non_uniform
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
# These modes have separate uniform decode keys
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False)
else:
assert rt_mode == CUDAGraphMode.NONE

Expand All @@ -180,7 +184,7 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):

if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
assert key == replace(desc_full_exact, num_reqs=None, uniform=False)
else:
assert rt_mode == CUDAGraphMode.NONE

Expand Down
18 changes: 3 additions & 15 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, NamedTuple
from typing import Any

import torch

Expand All @@ -26,7 +26,8 @@
batchsize_forward_time: defaultdict = defaultdict(list)


class BatchDescriptor(NamedTuple):
@dataclass(frozen=True)
class BatchDescriptor:
"""
Batch descriptor for cudagraph dispatching. We should keep the num of
items as minimal as possible to properly and uniquely describe the padded
Expand Down Expand Up @@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple):
to be properly captured.
"""

def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a relaxed version of current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)


def _compute_sp_num_tokens(
num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
Expand Down
13 changes: 11 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
Expand Down Expand Up @@ -310,8 +310,17 @@ def __init__(
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)
Expand Down
12 changes: 11 additions & 1 deletion vllm/v1/attention/backends/mla/flashattn_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
Expand Down Expand Up @@ -129,8 +130,17 @@ def __init__(
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

if self.use_full_cuda_graph and self.fa_aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)
Expand Down
37 changes: 21 additions & 16 deletions vllm/v1/cudagraph_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from itertools import product

from vllm.config import CUDAGraphMode, VllmConfig
Expand Down Expand Up @@ -180,12 +181,14 @@ def initialize_cudagraph_keys(
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
batch_desc = self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
)
# Only relax for PIECEWISE mode. FULL mode needs exact num_reqs
# because FA3's scheduler_metadata computation depends on it.
if cudagraph_mode.mixed_mode() == CUDAGraphMode.PIECEWISE:
batch_desc = replace(batch_desc, num_reqs=None, uniform=False)
self.add_cudagraph_key(cudagraph_mode.mixed_mode(), batch_desc)

# if decode cudagraph mode is FULL, and we don't already have mixed
# mode full cudagraphs then add them here.
Expand Down Expand Up @@ -264,21 +267,23 @@ def dispatch(
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

if not disable_full:
# check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc

# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc
# check if key exists for full cudagraph
# For pure FULL mode, keys are registered with uniform=False.
batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if (
not disable_full
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
):
return CUDAGraphMode.FULL, batch_desc_to_check

# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check

# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
Expand Down