Merged
12 changes: 9 additions & 3 deletions vllm/utils/torch_utils.py
@@ -3,6 +3,7 @@
import contextlib
import importlib.metadata
import os
import platform
import random
import threading
from collections.abc import Callable, Collection
@@ -67,6 +68,11 @@
T = TypeVar("T")


# Pin memory in non-WSL case.
# Logic duplicated here for now to avoid circular import.
PIN_MEMORY = "microsoft" not in " ".join(platform.uname()).lower()


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return (
kv_cache_dtype.startswith("fp8")
@@ -602,12 +608,12 @@ def create_kv_caches_with_random(
def async_tensor_h2d(
data: list,
dtype: torch.dtype,
target_device: str | torch.device,
pin_memory: bool,
device: str | torch.device,
pin_memory: bool = PIN_MEMORY,
) -> torch.Tensor:
"""Asynchronously create a tensor and copy it from host to device."""
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
return t.to(device=device, non_blocking=True)


def make_ndarray_with_pad(
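For context on the new default, a minimal sketch of the pinned-staging pattern `async_tensor_h2d` implements (pin host memory outside WSL, then a non-blocking upload). The variable names and the CPU fallback are illustrative additions, not part of the PR:

```python
import platform
import torch

# Same WSL check as the new PIN_MEMORY constant above; pinning host pages
# under WSL is avoided. Also skip pinning when no CUDA device is present so
# this sketch runs anywhere.
pin = torch.cuda.is_available() and "microsoft" not in " ".join(platform.uname()).lower()

# Stage the Python list in (optionally pinned) host memory first...
host = torch.tensor([1, 2, 3, 4], dtype=torch.int32, pin_memory=pin, device="cpu")

# ...then copy to the device without blocking the host thread. The copy only
# truly overlaps with CPU work when the source buffer is pinned.
target = "cuda" if torch.cuda.is_available() else "cpu"
dev = host.to(device=target, non_blocking=True)
```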
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -1105,7 +1105,8 @@ def build(
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
prefill_start : num_reqs + 1
]
paged_kv_indptr_prefill_gpu[0] = 0
# Assign to slice to avoid cpu sync.
Member:
a lot of real black magic in this pr

Member Author:
`cuda_tensor[0] = 0` uses `copy_`, which does a sync; `cuda_tensor[:1] = 0` uses `fill_`, which doesn't :)
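Purely as an illustration of the behaviour described above (not code from the PR; assumes a CUDA device is available):

```python
import torch

x = torch.zeros(8, device="cuda")

# Integer indexing routes through copy_ with a host scalar, so the host
# thread waits on the device before continuing.
x[0] = 0

# Slice indexing dispatches to fill_, a device-side kernel launch, so the
# host thread keeps going without synchronizing.
x[:1] = 0
```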

paged_kv_indptr_prefill_gpu[:1] = 0
torch.cumsum(
num_blocks_per_req,
dim=0,
48 changes: 30 additions & 18 deletions vllm/v1/attention/backends/flex_attention.py
@@ -48,12 +48,16 @@
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)


def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
device = offsets.device
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(
torch.arange(len(counts), device=device, dtype=torch.int32), counts
def _offsets_to_doc_ids_tensor(
offsets_cpu: torch.Tensor, device: torch.device
) -> torch.Tensor:
# Build on CPU (so `repeat_interleave` doesn't force a GPU->CPU sync to
# learn the data-dependent output length) and upload non-blocking.
counts = offsets_cpu[1:] - offsets_cpu[:-1]
doc_ids = torch.repeat_interleave(
torch.arange(len(counts), dtype=torch.int32), counts
)
return doc_ids.to(device, non_blocking=True)


def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
@@ -290,11 +294,13 @@ def unique_static_unsorted(
keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N]

# ── left-pack uniques into a fresh tensor ───────────────────────────
# Route non-kept entries to a garbage slot at column N so we can do a
# single scatter rather than using torch.nonzero (which would force a
# GPU->CPU sync to enumerate kept positions).
dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go
packed_flat = torch.full_like(x_flat, pad_val)

rows, src_cols = torch.nonzero(keep, as_tuple=True)
packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols]
dest_pos = torch.where(keep, dest_pos, N)
packed_extended = torch.full((B, N + 1), pad_val, device=device, dtype=x_flat.dtype)
packed_flat = packed_extended.scatter_(1, dest_pos, x_flat)[:, :N]

# ── restore original layout ─────────────────────────────────────────
packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim)
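As a small worked example of the left-packing scatter in the hunk above (shapes and values are made up, purely illustrative): for one row with `keep = [1, 0, 1]`, `cumsum - 1` gives destinations `[0, 0, 1]`, the dropped entry is routed to the spare column `N`, and a single `scatter_` left-packs the kept values:

```python
import torch

x_flat = torch.tensor([[7, 3, 5]])           # B=1, N=3
keep = torch.tensor([[True, False, True]])
pad_val = -1
B, N = x_flat.shape

dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1   # [[0, 0, 1]]
dest_pos = torch.where(keep, dest_pos, N)                  # dropped -> spare column N
packed_extended = torch.full((B, N + 1), pad_val, dtype=x_flat.dtype)
packed = packed_extended.scatter_(1, dest_pos, x_flat)[:, :N]
# packed == [[7, 5, -1]]; no torch.nonzero, hence no GPU->CPU sync needed.
```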
@@ -346,6 +352,9 @@ class FlexAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
# CPU-resident copy of query_start_loc used to derive doc_ids without a
# GPU->CPU sync from repeat_interleave's data-dependent output size.
query_start_loc_cpu: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
@@ -452,12 +461,7 @@ def final_mask_mod(
(is_valid, logical_q_idx, logical_kv_idx) = (
self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
)
# Apply mask modification only for valid indices
return torch.where(
is_valid,
self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
False,
)
return is_valid & self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx)

return final_mask_mod

@@ -469,7 +473,9 @@ def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
packed query sequences.
"""
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
request_lookup = _offsets_to_doc_ids_tensor(
self.query_start_loc_cpu, self.query_start_loc.device
)

def final_mask_mod(
b: torch.Tensor,
@@ -581,7 +587,9 @@ def get_transformed_score_mod(self) -> _score_mod_signature | None:
return None

# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
request_lookup = _offsets_to_doc_ids_tensor(
self.query_start_loc_cpu, self.query_start_loc.device
)
user_score_mod = self.score_mod

def transformed_score_mod(
@@ -726,7 +734,9 @@ def __post_init__(self):
assert self.prefix_kv_lens is None, "Not implemented yet."
assert self.suffix_kv_lens is None, "Not implemented yet."
# Create a lookup mapping from query indices -> request number
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
self.doc_ids = _offsets_to_doc_ids_tensor(
self.query_start_loc_cpu, self.query_start_loc.device
)
self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids)
self.num_blocks = self.total_cache_tokens // self.block_size

@@ -807,6 +817,7 @@ def build(

max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
@@ -871,6 +882,7 @@ def build(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
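A rough sketch of the doc-id construction the flex-attention changes move onto the CPU (the values are illustrative; `query_start_loc_cpu` is assumed to be the usual cumulative-length prefix from the metadata builder):

```python
import torch

# Three packed requests with query lengths 2, 1, 3.
query_start_loc_cpu = torch.tensor([0, 2, 3, 6])

counts = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]   # [2, 1, 3]
doc_ids = torch.repeat_interleave(
    torch.arange(len(counts), dtype=torch.int32), counts
)
# doc_ids == [0, 0, 1, 2, 2, 2]. The output length is data dependent, so
# doing this on the GPU would force a device sync just to size the result;
# building on CPU and uploading non-blocking avoids that.
if torch.cuda.is_available():
    doc_ids = doc_ids.to("cuda", non_blocking=True)
```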
37 changes: 18 additions & 19 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -9,6 +9,7 @@

from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -270,16 +271,20 @@ def _build_chunk_metadata_tensors(
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens

num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
# Derive prefill context lengths from CPU data only.
# `seq_lens_cpu_upper_bound` is precise for prefill rows in all modes
# (including async spec decode), so this avoids the D2H sync that
# `compute_num_computed_tokens().cpu()` would force.
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
assert seq_lens_cpu is not None
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
prefill_query_lens_cpu = query_start_loc_p_cpu[1:] - query_start_loc_p_cpu[:-1]
num_computed_tokens_p_cpu = (
seq_lens_cpu[num_reqs - num_prefills : num_reqs] - prefill_query_lens_cpu
)

cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
chunk_size,
@@ -289,20 +294,14 @@
)

device = common_attn_metadata.query_start_loc.device
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=device,
dtype=torch.int32,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=device,
dtype=torch.int32,
# Build on pinned CPU and upload non-blocking to avoid the synchronous
# H2D copy that `torch.as_tensor(list, device=cuda)` would force.
cu_chunk_seqlen_p = async_tensor_h2d(
cu_chunk_seqlen, dtype=torch.int32, device=device
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=device,
dtype=torch.int32,
seq_idx_p = async_tensor_h2d(seq_idx, dtype=torch.int32, device=device)
last_chunk_indices_p = async_tensor_h2d(
last_chunk_indices, dtype=torch.int32, device=device
)
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p

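To make the CPU-only context-length arithmetic above concrete, a toy example with made-up numbers (one decode request followed by two prefills; not code from the PR):

```python
import torch

num_reqs, num_prefills, num_decode_tokens = 3, 2, 1

# Cumulative query lengths for [decode(1 tok), prefill(4 toks), prefill(6 toks)].
query_start_loc_cpu = torch.tensor([0, 1, 5, 11])
# Total sequence length (context + new query tokens) per request.
seq_lens_cpu = torch.tensor([9, 12, 6])

q_start_loc_p = query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens  # [0, 4, 10]
prefill_query_lens = q_start_loc_p[1:] - q_start_loc_p[:-1]                   # [4, 6]
num_computed_tokens_p = (
    seq_lens_cpu[num_reqs - num_prefills : num_reqs] - prefill_query_lens
)
# num_computed_tokens_p == [8, 0]; everything stays on the CPU, no D2H sync.
```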
58 changes: 52 additions & 6 deletions vllm/v1/attention/backends/tree_attn.py
@@ -90,6 +90,14 @@ class TreeAttentionMetadata:
num_prefills: int = 0
num_decodes: int = 0

# Precomputed (on CPU in the builder) max_query_len and max_seq_len for
# the prefill-only and decode-only sub-batches. Used by the properties
# below to avoid a GPU->CPU sync via `.max().item()` on every forward.
max_query_len_prefill: int = 0
max_seq_len_prefill: int = 0
max_query_len_decode: int = 0
max_seq_len_decode: int = 0

tree_attn_bias: torch.Tensor | None = None

# Cached Prefill/decode metadata.
@@ -107,14 +115,13 @@ def prefill_metadata(self) -> "TreeAttentionMetadata | None":
return self._cached_prefill_metadata

q_start_loc = self.query_start_loc[self.num_decodes :]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[self.num_decodes :]
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_prefill_tokens,
max_query_len=int(q_seqlens.max().item()),
max_query_len=self.max_query_len_prefill,
query_start_loc=q_start_loc - q_start_loc[0],
max_seq_len=int(kv_seqlens.max().item()),
max_seq_len=self.max_seq_len_prefill,
seq_lens=kv_seqlens,
block_table=self.block_table[self.num_decodes :],
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
@@ -132,14 +139,13 @@ def decode_metadata(self) -> "TreeAttentionMetadata | None":
return self._cached_decode_metadata

q_start_loc = self.query_start_loc[: self.num_decodes + 1]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[: self.num_decodes]
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_decode_tokens,
max_query_len=int(q_seqlens.max().item()),
max_query_len=self.max_query_len_decode,
query_start_loc=q_start_loc,
max_seq_len=int(kv_seqlens.max().item()),
max_seq_len=self.max_seq_len_decode,
seq_lens=kv_seqlens,
block_table=self.block_table[: self.num_decodes],
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
@@ -199,6 +205,42 @@ def build(
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

# Precompute prefill/decode sub-batch max_query_len / max_seq_len on
# CPU so the prefill_metadata / decode_metadata properties don't need
# a GPU->CPU sync via `.max().item()` on every forward.
# Prefer `seq_lens_cpu_upper_bound` over the (deprecated)
# `seq_lens_cpu` property: the upper bound is precise for prefill
# rows and optimistic-but-safe for decode rows (workspace sizing
# from `max()` is fine with an over-estimate), and avoids the
# `seq_lens.to("cpu")` sync the property would fall through to in
# async-spec-decode mode. The draft-attention path (eagle
# speculator) doesn't populate it; fall back to the batch-wide
# `max_seq_len` as a safe upper bound for both sub-batches.
q_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
if num_prefills > 0:
q_seqlens_p = torch.diff(q_start_loc_cpu[num_decodes:])
max_query_len_prefill = int(q_seqlens_p.max())
max_seq_len_prefill = (
int(seq_lens_cpu[num_decodes:].max())
if seq_lens_cpu is not None
else max_seq_len
)
else:
max_query_len_prefill = 0
max_seq_len_prefill = 0
if num_decodes > 0:
q_seqlens_d = torch.diff(q_start_loc_cpu[: num_decodes + 1])
max_query_len_decode = int(q_seqlens_d.max())
max_seq_len_decode = (
int(seq_lens_cpu[:num_decodes].max())
if seq_lens_cpu is not None
else max_seq_len
)
else:
max_query_len_decode = 0
max_seq_len_decode = 0

return TreeAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_prefill_tokens=num_prefill_tokens,
@@ -211,6 +253,10 @@
seq_lens=kv_seqlens,
block_table=block_table,
slot_mapping=slot_mapping,
max_query_len_prefill=max_query_len_prefill,
max_seq_len_prefill=max_seq_len_prefill,
max_query_len_decode=max_query_len_decode,
max_seq_len_decode=max_seq_len_decode,
tree_attn_bias=self.tree_attn_bias,
)

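A minimal sketch of the CPU-side maxima the tree-attention builder now precomputes (illustrative values; the real builder reads these from `common_attn_metadata`):

```python
import torch

num_decodes = 2
query_start_loc_cpu = torch.tensor([0, 1, 2, 7, 15])   # 2 decodes + 2 prefills
seq_lens_cpu = torch.tensor([30, 18, 7, 15])

# Decode sub-batch: plain Python ints, no .max().item() on a CUDA tensor.
max_query_len_decode = int(torch.diff(query_start_loc_cpu[: num_decodes + 1]).max())
max_seq_len_decode = int(seq_lens_cpu[:num_decodes].max())

# Prefill sub-batch.
max_query_len_prefill = int(torch.diff(query_start_loc_cpu[num_decodes:]).max())
max_seq_len_prefill = int(seq_lens_cpu[num_decodes:].max())
# -> decode: (1, 30), prefill: (8, 15)
```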
9 changes: 4 additions & 5 deletions vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,7 @@
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.utils.torch_utils import async_tensor_h2d, is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -117,10 +117,9 @@ def compute_mm_prefix_range_tensor(
for r in range_lists:
padded_r = list(r) + [(0, 0)] * (max_ranges - len(r))
padded.append(padded_r)
# Create tensor with efficient H2D transfer
return torch.tensor(padded, dtype=torch.int32, device=device).view(
num_seqs, max_ranges, 2
)
# Build on pinned CPU memory so the H2D transfer is non-blocking.
padded = async_tensor_h2d(padded, dtype=torch.int32, device=device)
return padded.view(num_seqs, max_ranges, 2)


class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
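And a toy version of the ragged-range padding that now feeds `async_tensor_h2d` in the hunk above (made-up ranges; the pin/CPU fallbacks are only so the sketch runs without a GPU):

```python
import torch

# Two sequences with a different number of (start, end) ranges each.
range_lists = [[(0, 3), (5, 9)], [(2, 4)]]
num_seqs = len(range_lists)
max_ranges = max(len(r) for r in range_lists)

# Pad every sequence's range list to the same length with (0, 0) entries.
padded = [list(r) + [(0, 0)] * (max_ranges - len(r)) for r in range_lists]

# Stage on (optionally pinned) host memory, then upload without blocking;
# the result is shaped [num_seqs, max_ranges, 2].
pin = torch.cuda.is_available()
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.tensor(padded, dtype=torch.int32, pin_memory=pin, device="cpu")
t = t.to(device, non_blocking=True).view(num_seqs, max_ranges, 2)
```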