71 commits
78704e6
[Core] Add VLLM_GPU_SYNC_CHECK env var
njhill Apr 21, 2026
f9a94b5
add prefill-only option
njhill Apr 22, 2026
51a6a03
skip initialization cases
njhill Apr 22, 2026
ce71b52
avoid various runtime gpu syncs
njhill Apr 22, 2026
3ed63ba
relocate util methods
njhill Apr 22, 2026
d4c251c
fix cpu case
njhill Apr 22, 2026
3735c1a
fix torch compile
njhill Apr 22, 2026
151eb52
remove prefill-only
njhill Apr 22, 2026
c631e7c
suppress a couple of other places
njhill Apr 22, 2026
3ad45a9
fix sync in mm triton attn
njhill Apr 22, 2026
8c77f74
granular suppression for non-async-scheduling case
njhill Apr 22, 2026
5e03d53
minor
njhill Apr 22, 2026
d530b72
avoid sync in flex attention
njhill Apr 22, 2026
8010fa7
fix mamba
njhill Apr 22, 2026
f19e5ba
fix sync in mrv2 penalties
njhill Apr 22, 2026
595dee8
handle qwen2_audio
njhill Apr 22, 2026
7511af6
terratorch
njhill Apr 23, 2026
708af9c
fix tokwise and seqwise poolers
njhill Apr 23, 2026
cb156f6
fix bert
njhill Apr 23, 2026
fed36ab
suppress lora cpu sync
njhill Apr 23, 2026
9c0b8c0
fix tree attn
njhill Apr 23, 2026
60ff592
simplify
njhill Apr 23, 2026
3a5e14e
flex attn
njhill Apr 23, 2026
bc8d302
ultravox
njhill Apr 23, 2026
23958e2
qwen2.5
njhill Apr 23, 2026
ddfce27
add skip_first arg to with_gpu_sync_check decorator
njhill Apr 23, 2026
dbec8e3
dp_utils
njhill Apr 23, 2026
60eb033
phi3
njhill Apr 23, 2026
3e7c50e
mamba
njhill Apr 23, 2026
8a39825
more flex
njhill Apr 23, 2026
390be12
example connector
njhill Apr 23, 2026
ac82480
internvl
njhill Apr 23, 2026
a2f0881
example logits processor
njhill Apr 23, 2026
0d48de2
only enable post-warmup
njhill Apr 23, 2026
e8b6875
glm4.1, qwen2-vl, qwen3-omni
njhill Apr 23, 2026
481e1ee
inductor lazy init
njhill Apr 23, 2026
fa66c6d
mamba pfx cache test stub
njhill Apr 23, 2026
fd0d1c8
more tree_attn
njhill Apr 23, 2026
52975f5
lora ops
njhill Apr 23, 2026
00a94fe
async ngram
njhill Apr 23, 2026
7aef85b
transformers mm
njhill Apr 23, 2026
cd2110e
mamba mixer
njhill Apr 23, 2026
cce472e
flashinfer
njhill Apr 23, 2026
5f0c86f
qwen2.5-vl
njhill Apr 23, 2026
6da466a
fix tree_attn
njhill Apr 24, 2026
855d3c2
tokwise pooler
njhill Apr 24, 2026
8e722ae
idefics3
njhill Apr 24, 2026
6fa40cf
more mamba_attn
njhill Apr 24, 2026
6a1fd14
fix preemptive lazy_init
njhill Apr 24, 2026
69be8d1
gemma3 mm
njhill Apr 24, 2026
7a6bba8
minor formatting
njhill Apr 24, 2026
a9e33b1
move inductor lazy init to util method
njhill Apr 24, 2026
df26b02
fast prefill
njhill Apr 24, 2026
e5b45df
lora load adapter
njhill Apr 24, 2026
ccb7c03
idefics2
njhill Apr 24, 2026
e509ea8
phi4mm_audio
njhill Apr 24, 2026
24afeb6
temp
njhill Apr 24, 2026
bd27f57
temp2
njhill Apr 24, 2026
fb51bda
temp3
njhill Apr 24, 2026
e44ac81
fix custom lp
njhill Apr 24, 2026
ceda006
temp4
njhill Apr 24, 2026
a7f931f
temp5
njhill Apr 24, 2026
9822304
h2d util
njhill Apr 27, 2026
93ac29e
use async_tensor_h2d utility function
njhill Apr 27, 2026
be7b548
avoid circular import
njhill Apr 27, 2026
4805d7b
typo
njhill Apr 27, 2026
ddc75a8
qwen recompute_mrope_positions; qwen3_vl updates
njhill Apr 27, 2026
610ff40
switch gpu_sync_allowed count to first_only bool
njhill Apr 27, 2026
b3007c0
remove now-redundant guards in grouped_topk_router.py
njhill Apr 27, 2026
bcbbb93
qwen3_asr
njhill Apr 28, 2026
a926b48
post-rebase fixups
njhill Apr 30, 2026
3 changes: 3 additions & 0 deletions docker/Dockerfile
@@ -747,6 +747,9 @@ ENV HF_XET_HIGH_PERFORMANCE 1
# increase timeout for hf downloads (for testing)
ENV HF_HUB_DOWNLOAD_TIMEOUT 60

# Catch GPU<->CPU syncs in execute_model/sample_tokens
ENV VLLM_GPU_SYNC_CHECK=error

# Copy in the v1 package for testing (it isn't distributed yet)
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1

3 changes: 3 additions & 0 deletions docker/Dockerfile.rocm
@@ -427,6 +427,9 @@ ENV HF_XET_HIGH_PERFORMANCE=1
# increase timeout for hf downloads (for testing)
ENV HF_HUB_DOWNLOAD_TIMEOUT 60

# Catch GPU<->CPU syncs in execute_model/sample_tokens
ENV VLLM_GPU_SYNC_CHECK=error

# install audio decode package `torchcodec` from source (required due to
# ROCm and torch version mismatch) for tests with datasets package
COPY tools/install_torchcodec_rocm.sh /tmp/install_torchcodec.sh
24 changes: 14 additions & 10 deletions tests/v1/e2e/general/test_mamba_prefix_cache.py
@@ -71,11 +71,14 @@ def fake_sample_fn(
first_token_id_index = num_computed_tokens + 1
if spec_decode_metadata is None:
return SamplerOutput(
# Build on pinned CPU + non_blocking H2D rather than
# `torch.tensor(..., device=DEVICE_TYPE)` which would force
# a synchronous copy and trip the sync check.
sampled_token_ids=torch.tensor(
[[prompt_token_ids[first_token_id_index]]],
device=DEVICE_TYPE,
pin_memory=True,
dtype=torch.int32,
),
).to(DEVICE_TYPE, non_blocking=True),
logprobs_tensors=None,
)
accepted_tokens = prompt_token_ids[
@@ -86,9 +89,9 @@ def fake_sample_fn(
return SamplerOutput(
sampled_token_ids=torch.tensor(
[sampled_token_ids],
device=DEVICE_TYPE,
pin_memory=True,
dtype=torch.int32,
),
).to(DEVICE_TYPE, non_blocking=True),
logprobs_tensors=None,
)

@@ -126,29 +129,30 @@ def fake_propose_draft_token_ids_fn(
]
]

# Build on pinned CPU + non-blocking upload to avoid synchronous H2D.
next_token_ids = torch.tensor(
prompt_token_ids[
first_token_id_index - 1 : first_token_id_index
- 1
+ num_accepted_tokens
],
device=DEVICE_TYPE,
dtype=torch.int32,
)
pin_memory=True,
).to(DEVICE_TYPE, non_blocking=True)

valid_sampled_tokens_count = torch.tensor(
[num_accepted_tokens],
device=DEVICE_TYPE,
dtype=torch.int32,
)
pin_memory=True,
).to(DEVICE_TYPE, non_blocking=True)

self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)

return torch.tensor(
proposed_draft_token_ids,
device=DEVICE_TYPE,
dtype=torch.int32,
)
pin_memory=True,
).to(DEVICE_TYPE, non_blocking=True)

return fake_propose_draft_token_ids_fn
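
Editor's note: a minimal sketch of the staging pattern these test changes adopt (the token values below are illustrative). Small tensors are built in pinned host memory and uploaded with `non_blocking=True`, instead of `torch.tensor(..., device=...)`, which performs a blocking host-to-device copy that the sync check flags.

```python
import torch

token_ids = [101, 102, 103]  # illustrative data

# Before: synchronous H2D copy, which torch.cuda.set_sync_debug_mode
# would warn on or raise for.
# t = torch.tensor(token_ids, dtype=torch.int32, device="cuda")

# After: pinned staging buffer + asynchronous upload.
t = torch.tensor(token_ids, dtype=torch.int32, pin_memory=True).to(
    "cuda", non_blocking=True
)
```

Pinning is what makes the copy genuinely asynchronous; without pinned memory the `non_blocking=True` flag has no effect for host-to-device transfers.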

24 changes: 16 additions & 8 deletions tests/v1/logits_processors/utils.py
@@ -86,17 +86,25 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info:
return logits

# Save target values before modification
# Save target values before modification. Build on pinned CPU then
# non-blocking upload to avoid a synchronous H2D copy.
cols = torch.tensor(
list(self.req_info.values()), dtype=torch.long, device=logits.device
)
list(self.req_info.values()), dtype=torch.long, pin_memory=True
).to(logits.device, non_blocking=True)
rows = torch.tensor(
list(self.req_info.keys()), dtype=torch.long, device=logits.device
)
list(self.req_info.keys()), dtype=torch.long, pin_memory=True
).to(logits.device, non_blocking=True)
values_to_keep = logits[rows, cols].clone()

# Mask all but target tokens
logits[rows] = float("-inf")
# Mask all but target tokens. Use an on-device fill tensor so the
# scatter doesn't force a synchronizing scalar H2D.
fill = torch.full(
(rows.numel(), logits.size(-1)),
float("-inf"),
dtype=logits.dtype,
device=logits.device,
)
logits[rows] = fill
logits[rows, cols] = values_to_keep

return logits
@@ -142,7 +150,7 @@ def __call__(
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
val_to_keep = logits[self.target_token].item()
val_to_keep = logits[self.target_token].clone()
logits[:] = float("-inf")
logits[self.target_token] = val_to_keep
return logits
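
Editor's note: the `__call__` change above swaps `.item()` for `.clone()`. A small illustration of why (the vocab size is made up): `.item()` materializes the value as a Python scalar, which forces a device-to-host sync, whereas `.clone()` keeps it on the GPU so the later writes can use it without synchronizing.

```python
import torch

logits = torch.randn(32_000, device="cuda")  # illustrative vocab size
target_token = 42

# Synchronizes: copies one value back to the host.
# val_to_keep = logits[target_token].item()

# Stays on device: no GPU<->CPU sync.
val_to_keep = logits[target_token].clone()
logits[:] = float("-inf")
logits[target_token] = val_to_keep
```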
28 changes: 28 additions & 0 deletions vllm/compilation/compiler_interface.py
@@ -761,6 +761,34 @@ def set_functorch_config() -> None:
setattr(torch._functorch.config, k, v)


def trigger_inductor_lazy_init(device: torch.device | None = None) -> None:
"""Eagerly trigger inductor's once-per-process lazy inits (SFDP pattern
matcher, pad_mm, misc patterns).

These normally fire on the first torch.compile invocation and include
CUDA syncs. If warmup hits the on-disk compile cache, no compilation
actually runs, so they never fire during warmup and would instead trip
the sync check on the first real-request cache miss once the gate is on.

Private torch API; best-effort. Newer torch versions take an
`input_device` argument and cache per-device, so pass the current CUDA
device to ensure the cache key matches later compile calls.
"""
try:
import inspect

from torch._inductor.fx_passes.joint_graph import (
lazy_init as _inductor_lazy_init,
)

if inspect.signature(_inductor_lazy_init).parameters:
_inductor_lazy_init(device)
else:
_inductor_lazy_init()
except Exception as e: # noqa: BLE001
logger.info("Skipping inductor lazy_init pre-trigger: %s", e)


class EagerAdaptor(CompilerInterface):
name = "eager"

@@ -139,6 +139,10 @@ def inject_kv_into_layer(
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
# `slot_mapping` is built CPU-side in `ReqMeta.make_meta`; upload
# non-blocking so the advanced-index ops below don't force a
# synchronous H2D of the index tensor.
slot_mapping = slot_mapping.to(dst_kv_cache_layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
@@ -188,7 +192,8 @@ def inject_kv_into_layer(
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
kv_cache_cpu = safetensors.torch.load_file(filename)["kv_cache"]
kv_cache = kv_cache_cpu.to("cuda", non_blocking=True)
if isinstance(attn_metadata, dict):
inject_kv_into_layer(
kv_cache_layer,
@@ -235,6 +240,10 @@ def extract_kv_from_layer(
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
# `slot_mapping` is built CPU-side in `ReqMeta.make_meta`; upload
# non-blocking so the advanced-index ops below don't force a
# synchronous H2D of the index tensor.
slot_mapping = slot_mapping.to(layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
@@ -245,6 +254,8 @@
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]

from vllm.utils.gpu_sync_debug import gpu_sync_allowed

connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, ExampleConnectorMetadata)
for request in connector_metadata.requests:
@@ -253,7 +264,9 @@
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
# `.cpu()` is an unavoidable D2H to serialize the cache.
with gpu_sync_allowed():
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)

def wait_for_save(self):
8 changes: 8 additions & 0 deletions vllm/envs.py
@@ -82,6 +82,7 @@
VLLM_MAIN_CUDA_VERSION: str = "13.0"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
VLLM_BATCH_INVARIANT: bool = False
VLLM_GPU_SYNC_CHECK: Literal["warn", "error"] | None = None
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False
@@ -523,6 +524,13 @@ def _get_or_set_default() -> str:
# Enable batch-invariant mode: deterministic results regardless of
# batch composition. Requires NVIDIA GPU with compute capability >= 9.0.
"VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
# If set, enable PyTorch's GPU<->CPU synchronization debug mode around
# the worker's `execute_model` and `sample_tokens` calls. Valid values
# are "warn" (print a warning on each sync) or "error" (raise on sync).
# Unset disables the check. See `torch.cuda.set_sync_debug_mode`.
"VLLM_GPU_SYNC_CHECK": env_with_choices(
"VLLM_GPU_SYNC_CHECK", None, ["warn", "error"], case_sensitive=False
),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
35 changes: 26 additions & 9 deletions vllm/lora/ops/triton_ops/utils.py
@@ -13,6 +13,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import async_tensor_h2d

logger = init_logger(__name__)
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
@@ -49,7 +50,11 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
lora_strides_d1.append(lora_a_weight.stride(1))
lora_strides_d2.append(lora_a_weight.stride(2))
if len(lora_a_weights) > 1:
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
# Pinned CPU + non_blocking H2D avoids the synchronous copy that
# `torch.tensor(list, device=cuda)` would otherwise force.
lora_ptr_tensor = async_tensor_h2d(
tensor_ptrs, dtype=torch.uint64, device=device
)
else:
lora_ptr_tensor = lora_a_weights[0]

@@ -106,10 +111,13 @@ def _get_lora_b_ptr(
hidden_sizes.append(lora_b_weight.size(1))

if len(lora_weights) > 1:
# note these are device tensors
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
slice_start_tensor = torch.tensor(
slice_offset_lst, device=device, dtype=torch.uint64
# note these are device tensors. Pinned CPU + non_blocking H2D
# avoids the sync that `torch.tensor(list, device=cuda)` forces.
lora_ptr_tensor = async_tensor_h2d(
tensor_ptrs, dtype=torch.uint64, device=device
)
slice_start_tensor = async_tensor_h2d(
slice_offset_lst, dtype=torch.uint64, device=device
)
else:
slice_start_tensor = slice_offset_lst[0]
@@ -129,10 +137,19 @@
same_stride = True

else:
lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
# Pinned CPU + non_blocking H2D to avoid blocking copies.
lora_strides_d0_tensor = async_tensor_h2d(
lora_strides_d0, dtype=torch.int64, device=device
)
lora_strides_d1_tensor = async_tensor_h2d(
lora_strides_d1, dtype=torch.int64, device=device
)
lora_strides_d2_tensor = async_tensor_h2d(
lora_strides_d2, dtype=torch.int64, device=device
)
hidden_sizes_tensor = async_tensor_h2d(
hidden_sizes, dtype=torch.int64, device=device
)
same_stride = False
# MAX_N is the maximum hidden size among all the lora_b weights
MAX_N = max(hidden_sizes)
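
Editor's note: `async_tensor_h2d` is the helper these constructions are routed through (commit `9822304`, "h2d util"); its implementation in `vllm/utils/torch_utils.py` is not shown in this excerpt. Inferred from the call sites above, it plausibly amounts to the pinned-staging pattern; the body below is an assumption, not the real code.

```python
import torch


def async_tensor_h2d(data, *, dtype, device):
    """Stage `data` in pinned host memory, then copy it to `device`
    without blocking the host on the transfer (presumed behaviour)."""
    cpu_tensor = torch.tensor(
        data, dtype=dtype, pin_memory=torch.cuda.is_available()
    )
    return cpu_tensor.to(device, non_blocking=True)
```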
16 changes: 13 additions & 3 deletions vllm/lora/punica_wrapper/punica_gpu.py
@@ -14,6 +14,7 @@
from vllm.lora.layers import LoRAMapping
from vllm.lora.utils import get_captured_lora_counts
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.gpu_sync_debug import gpu_sync_allowed
from vllm.utils.math_utils import round_up

if HAS_TRITON:
@@ -83,9 +84,18 @@ def update_metadata(
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)

# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
# This method has two unavoidable GPU->CPU syncs given the current
# design: (1) the `torch.all(... == -1)` no-lora check below, and
# (2) `torch.unique(...)` + reading `lora_ids.size(0)` as a Python
# int further down. Both ultimately stem from needing facts about
# `token_lora_mapping`'s contents on the host (is everything -1?
# how many distinct loras?). TODO: compute these on CPU upstream
# in `convert_mapping` where the mapping is still a Python list,
# then pass the results in.
with gpu_sync_allowed():
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)

def add_shrink(
self,
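
Editor's note: a short sketch of what the TODO in the comment above points at (the list contents are illustrative). Both host-side facts can be computed while the mapping is still a Python list in `convert_mapping`, before any tensor reaches the GPU, so `update_metadata` would no longer need to sync.

```python
token_lora_mapping = [-1, -1, 3, 3, 7]  # illustrative per-token LoRA ids

no_lora = all(idx == -1 for idx in token_lora_mapping)
num_active_loras = len({idx for idx in token_lora_mapping if idx != -1})
```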
6 changes: 4 additions & 2 deletions vllm/lora/punica_wrapper/utils.py
@@ -5,6 +5,8 @@

import torch

from vllm.utils.torch_utils import async_tensor_h2d

if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
@@ -110,8 +112,8 @@ def convert_mapping(
embedding_indices,
]

indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(
indices = async_tensor_h2d(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = async_tensor_h2d(
prompt_mapping, dtype=torch.long, device=device
)
embeddings_indices = torch.stack(