2 changes: 1 addition & 1 deletion tests/v1/attention/test_chunked_local_attention.py
@@ -178,7 +178,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):

# Convert to numpy for easier comparison
actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
actual_k_seqlens = result.seq_lens_cpu.numpy()
actual_k_seqlens = result.seq_lens.cpu().numpy()

# Check that all query lengths are less than or equal to attn_chunk_size
assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
12 changes: 1 addition & 11 deletions tests/v1/attention/utils.py
@@ -67,15 +67,7 @@ def create_common_attn_metadata(

# Create sequence lengths
seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu()
max_seq_len = int(seq_lens_cpu.max())

# Create computed tokens (context length for each sequence)
context_lens = [
batch_spec.seq_lens[i] - batch_spec.query_lens[i]
for i in range(batch_spec.batch_size)
]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
max_seq_len = int(seq_lens.max().item())

# Create block table and slot mapping
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
@@ -106,8 +98,6 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
91 changes: 63 additions & 28 deletions tests/v1/e2e/test_async_spec_decode.py
@@ -16,53 +16,85 @@
@pytest.fixture
def sync_tracker():
"""
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
lazy init syncs. Prints stack traces immediately when syncs occur.
Fixture that patches CommonAttentionMetadata.seq_lens to detect .cpu() calls.
This tracks when code accesses seq_lens and converts it to CPU, which causes
a GPU-CPU sync that breaks async scheduling.
"""
from vllm.v1.attention.backend import CommonAttentionMetadata

# Shared counter for cross-process communication (inherited by fork)
sync_count = multiprocessing.Value("i", 0)

# Save original property
original_prop = CommonAttentionMetadata.seq_lens_cpu
original_fget = original_prop.fget

# Create tracking wrapper
def tracking_seq_lens_cpu(self):
if self._seq_lens_cpu is None:
# Increment counter
with sync_count.get_lock():
sync_count.value += 1
count = sync_count.value
# Print stack trace immediately (shows in subprocess output)
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
print(f"{'=' * 60}", file=sys.stderr)
traceback.print_stack(file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
sys.stderr.flush()
return original_fget(self)

# Apply patch
CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)
original_cpu = torch.Tensor.cpu

# Create a wrapper that tracks .cpu() calls on seq_lens tensors
tracked_tensors: set = set()

original_getattribute = CommonAttentionMetadata.__getattribute__

def tracking_getattribute(self, name):
value = original_getattribute(self, name)
if name == "seq_lens" and isinstance(value, torch.Tensor):
# Mark this tensor as one we want to track
tracked_tensors.add(id(value))
return value

# Backends that intentionally call .cpu() for their operations
ALLOWED_BACKENDS = ["flashinfer.py", "mla/indexer.py", "mla/flashmla_sparse.py"]

def tracking_cpu(tensor_self, *args, **kwargs):
if tensor_self.is_cuda and id(tensor_self) in tracked_tensors:
# Check if this is from an allowed backend
stack = traceback.format_stack()
stack_str = "".join(stack)
is_allowed = any(backend in stack_str for backend in ALLOWED_BACKENDS)
if not is_allowed:
with sync_count.get_lock():
sync_count.value += 1
count = sync_count.value
print(f"\n{'=' * 60}", file=sys.stderr)
print(
f"SYNC #{count}: .cpu() called on CommonAttentionMetadata.seq_lens",
file=sys.stderr,
)
print(
f"Shape: {tensor_self.shape}, dtype: {tensor_self.dtype}",
file=sys.stderr,
)
print(f"{'=' * 60}", file=sys.stderr)
traceback.print_stack(file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
sys.stderr.flush()
return original_cpu(tensor_self, *args, **kwargs)

# Apply patches
CommonAttentionMetadata.__getattribute__ = tracking_getattribute
torch.Tensor.cpu = tracking_cpu

class SyncTracker:
@property
def count(self) -> int:
return sync_count.value

def start_tracking(self):
"""Start tracking syncs from this point. Call after model loading."""
with sync_count.get_lock():
sync_count.value = 0
tracked_tensors.clear()

def assert_no_sync(self, msg: str = ""):
count = sync_count.value
assert count == 0, (
f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
f"{count} times. See stack traces above. {msg}"
f"Unexpected GPU-CPU sync: .cpu() called on "
f"CommonAttentionMetadata.seq_lens {count} times. "
f"See stack traces above. {msg}"
)

yield SyncTracker()

# Restore original property
CommonAttentionMetadata.seq_lens_cpu = original_prop
# Restore original methods
CommonAttentionMetadata.__getattribute__ = original_getattribute
torch.Tensor.cpu = original_cpu
torch._dynamo.reset()


@@ -116,6 +148,9 @@ def test_no_sync_with_spec_decode(
async_scheduling=True,
)

# Start tracking after model loading - we only care about syncs during generation
sync_tracker.start_tracking()

outputs = llm.generate(
["Hello, my name is"],
SamplingParams(temperature=0, max_tokens=10),
3 changes: 2 additions & 1 deletion tests/v1/spec_decode/test_eagle.py
@@ -166,14 +166,15 @@ def test_prepare_next_token_ids():
block_size=16,
device=device,
)
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()

expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)

next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata,
seq_lens_cpu,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,
4 changes: 0 additions & 4 deletions tests/v1/spec_decode/test_tree_attention.py
@@ -54,14 +54,12 @@ def forward_attention(
query_start_loc = q_len * torch.arange(
batch_size + 1, device=q.device, dtype=torch.int32
)
query_lens = torch.diff(query_start_loc)
seq_lens = torch.full(
(batch_size,),
seqlen_k,
device=q.device,
dtype=torch.int32,
)
context_lens = seq_lens - query_lens
max_seq_len = int(seq_lens.max())
max_query_len = q_len
num_actual_tokens = query_start_loc[-1]
@@ -89,8 +87,6 @@ def forward_attention(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens.cpu(),
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
17 changes: 8 additions & 9 deletions vllm/model_executor/layers/attention/cross_attention.py
@@ -85,26 +85,25 @@ def build(
new_metadata.causal = False
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicate decode step > 1 (no chunked prefill)
num_cache_decodes = (
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
query_lens_cpu = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
if num_cache_decodes > 0:
# Any computed tokens indicate decode step > 1 (no chunked prefill)
is_decode = seq_lens_cpu > query_lens_cpu
if torch.any(is_decode):
# CrossAttn KV cache has already been populated on first decoder step,
# skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
new_metadata.encoder_seq_lens_cpu = np.where(
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
is_decode, 0, new_metadata.encoder_seq_lens_cpu
)

# seq_lens is provided by model runner: initial encoder input length is
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)

# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
slot_mapping = _get_cross_slot_mapping(
2 changes: 0 additions & 2 deletions vllm/model_executor/models/whisper_causal.py
@@ -133,8 +133,6 @@ def build(
new_common_attn_metadata.query_start_loc *= block_pool_size
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
new_common_attn_metadata.seq_lens *= block_pool_size
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
new_common_attn_metadata.num_actual_tokens *= block_pool_size
new_common_attn_metadata.max_query_len *= block_pool_size
new_common_attn_metadata.max_seq_len *= block_pool_size
42 changes: 9 additions & 33 deletions vllm/v1/attention/backend.py
@@ -326,12 +326,6 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""

# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None

_num_computed_tokens_cache: torch.Tensor | None = None

def batch_size(self) -> int:
return self.seq_lens.shape[0]

@@ -342,6 +336,12 @@ def naive_query_lens(self) -> torch.Tensor:
def replace(self, **kwargs) -> "CommonAttentionMetadata":
return replace(self, **kwargs)

# WARNING: Deprecated fields. Will be removed in a future release
# Keep seq_lens_cpu for now to avoid performance regressions with FlashInfer on
# sm120 machines; will remove once FA4 is performant enough on sm120.
# see: https://github.com/vllm-project/vllm/pull/33771
_seq_lens_cpu: torch.Tensor | None = None

@property
@deprecated(
"""
@@ -355,29 +355,11 @@ def seq_lens_cpu(self) -> torch.Tensor:
self._seq_lens_cpu = self.seq_lens.to("cpu")
return self._seq_lens_cpu

@property
@deprecated(
"""
Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
async scheduling. If a CPU copy is needed, it can be derived from
query_start_loc_cpu and seq_lens.
Will be removed in a future release, please migrate as soon as possible.
"""
)
def num_computed_tokens_cpu(self) -> torch.Tensor:
if self._num_computed_tokens_cpu is None:
query_seq_lens = (
self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
)
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
return self._num_computed_tokens_cpu

def compute_num_computed_tokens(self) -> torch.Tensor:
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
if self._num_computed_tokens_cache is None:
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
self._num_computed_tokens_cache = self.seq_lens - query_lens
return self._num_computed_tokens_cache
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
num_computed_tokens = self.seq_lens - query_lens
return num_computed_tokens

# TODO(lucas): remove once we have FULL-CG spec-decode support
def unpadded(
Expand All @@ -388,12 +370,6 @@ def unpadded(
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
if self._seq_lens_cpu is not None
else None,
_num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
if self._num_computed_tokens_cpu is not None
else None,
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
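For context, the deprecation notes above steer callers toward device-side derivations from query_start_loc and seq_lens, with a single explicit copy only where host values are unavoidable (the indexer backend below follows that pattern). A minimal sketch of this migration path, illustrative only and not part of the diff:

import torch


# Illustrative sketch (hypothetical helper): CommonAttentionMetadata already
# exposes compute_num_computed_tokens() for the device-side case.
def derive_lengths_on_device(
    query_start_loc: torch.Tensor, seq_lens: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-request query lengths are consecutive differences of query_start_loc.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Tokens already present in the KV cache for each request.
    num_computed_tokens = seq_lens - query_lens
    return query_lens, num_computed_tokens


# If a backend genuinely needs host values (e.g. to split prefill chunks), it
# takes one explicit copy at metadata-build time instead of relying on the
# deprecated lazy property, accepting a single GPU-CPU sync at that point:
#     seq_lens_cpu = common_attn_metadata.seq_lens.cpu()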
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/mla/indexer.py
@@ -289,8 +289,10 @@ def build(

prefill_metadata = None
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()

chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size,
request_offset=num_decodes,
)
@@ -299,7 +301,7 @@
reqs_start,
reqs_end,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
seq_lens_cpu,
common_attn_metadata.block_table_tensor,
)
for reqs_start, reqs_end in chunk_seq_ids
9 changes: 2 additions & 7 deletions vllm/v1/attention/backends/utils.py
@@ -222,7 +222,7 @@ def make_local_attention_virtual_batches(
block_size: int = 0,
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens.cpu().numpy()
block_table = common_attn_metadata.block_table_tensor
device = common_attn_metadata.query_start_loc.device

@@ -285,7 +285,6 @@ def make_local_attention_virtual_batches(
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
num_computed_tokens_local = seqlens_k_local - seqlens_q_local

k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
@@ -354,8 +353,6 @@ def make_local_attention_virtual_batches(
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
), make_block_table


@@ -412,8 +409,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
return common_attn_metadata

@@ -443,7 +438,7 @@ def split_decodes_prefills_and_extends(
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens.cpu()

if max_query_len <= decode_threshold:
return num_reqs, 0, 0, num_tokens, 0, 0