diff --git a/docker/Dockerfile b/docker/Dockerfile index a6b291407713..273d21e69b6b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -747,6 +747,9 @@ ENV HF_XET_HIGH_PERFORMANCE 1 # increase timeout for hf downloads (for testing) ENV HF_HUB_DOWNLOAD_TIMEOUT 60 +# Catch GPU<->CPU syncs in execute_model/sample_tokens +ENV VLLM_GPU_SYNC_CHECK=error + # Copy in the v1 package for testing (it isn't distributed yet) COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f416c8136118..fe5ee01967f6 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -427,6 +427,9 @@ ENV HF_XET_HIGH_PERFORMANCE=1 # increase timeout for hf downloads (for testing) ENV HF_HUB_DOWNLOAD_TIMEOUT 60 +# Catch GPU<->CPU syncs in execute_model/sample_tokens +ENV VLLM_GPU_SYNC_CHECK=error + # install audio decode package `torchcodec` from source (required due to # ROCm and torch version mismatch) for tests with datasets package COPY tools/install_torchcodec_rocm.sh /tmp/install_torchcodec.sh diff --git a/tests/v1/e2e/general/test_mamba_prefix_cache.py b/tests/v1/e2e/general/test_mamba_prefix_cache.py index 8b9f7bb6c5ad..2bc7971ea8a8 100644 --- a/tests/v1/e2e/general/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/general/test_mamba_prefix_cache.py @@ -71,11 +71,14 @@ def fake_sample_fn( first_token_id_index = num_computed_tokens + 1 if spec_decode_metadata is None: return SamplerOutput( + # Build on pinned CPU + non_blocking H2D rather than + # `torch.tensor(..., device=DEVICE_TYPE)` which would force + # a synchronous copy and trip the sync check. sampled_token_ids=torch.tensor( [[prompt_token_ids[first_token_id_index]]], - device=DEVICE_TYPE, + pin_memory=True, dtype=torch.int32, - ), + ).to(DEVICE_TYPE, non_blocking=True), logprobs_tensors=None, ) accepted_tokens = prompt_token_ids[ @@ -86,9 +89,9 @@ def fake_sample_fn( return SamplerOutput( sampled_token_ids=torch.tensor( [sampled_token_ids], - device=DEVICE_TYPE, + pin_memory=True, dtype=torch.int32, - ), + ).to(DEVICE_TYPE, non_blocking=True), logprobs_tensors=None, ) @@ -126,29 +129,30 @@ def fake_propose_draft_token_ids_fn( ] ] + # Build on pinned CPU + non-blocking upload to avoid synchronous H2D. next_token_ids = torch.tensor( prompt_token_ids[ first_token_id_index - 1 : first_token_id_index - 1 + num_accepted_tokens ], - device=DEVICE_TYPE, dtype=torch.int32, - ) + pin_memory=True, + ).to(DEVICE_TYPE, non_blocking=True) valid_sampled_tokens_count = torch.tensor( [num_accepted_tokens], - device=DEVICE_TYPE, dtype=torch.int32, - ) + pin_memory=True, + ).to(DEVICE_TYPE, non_blocking=True) self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) return torch.tensor( proposed_draft_token_ids, - device=DEVICE_TYPE, dtype=torch.int32, - ) + pin_memory=True, + ).to(DEVICE_TYPE, non_blocking=True) return fake_propose_draft_token_ids_fn diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index e54da72e5e2e..f8d14508a99a 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -86,17 +86,25 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: return logits - # Save target values before modification + # Save target values before modification. Build on pinned CPU then + # non-blocking upload to avoid a synchronous H2D copy. 
cols = torch.tensor( - list(self.req_info.values()), dtype=torch.long, device=logits.device - ) + list(self.req_info.values()), dtype=torch.long, pin_memory=True + ).to(logits.device, non_blocking=True) rows = torch.tensor( - list(self.req_info.keys()), dtype=torch.long, device=logits.device - ) + list(self.req_info.keys()), dtype=torch.long, pin_memory=True + ).to(logits.device, non_blocking=True) values_to_keep = logits[rows, cols].clone() - # Mask all but target tokens - logits[rows] = float("-inf") + # Mask all but target tokens. Use an on-device fill tensor so the + # scatter doesn't force a synchronizing scalar H2D. + fill = torch.full( + (rows.numel(), logits.size(-1)), + float("-inf"), + dtype=logits.dtype, + device=logits.device, + ) + logits[rows] = fill logits[rows, cols] = values_to_keep return logits @@ -142,7 +150,7 @@ def __call__( output_ids: list[int], logits: torch.Tensor, ) -> torch.Tensor: - val_to_keep = logits[self.target_token].item() + val_to_keep = logits[self.target_token].clone() logits[:] = float("-inf") logits[self.target_token] = val_to_keep return logits diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 933554faa280..615510437ef2 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -761,6 +761,34 @@ def set_functorch_config() -> None: setattr(torch._functorch.config, k, v) +def trigger_inductor_lazy_init(device: torch.device | None = None) -> None: + """Eagerly trigger inductor's once-per-process lazy inits (SFDP pattern + matcher, pad_mm, misc patterns). + + These normally fire on the first torch.compile invocation and include + CUDA syncs. If warmup hits the on-disk compile cache, no compile actually + runs so these never fire during warmup, and they'd blow up on the first + real-request cache miss once the sync-check gate is on. + + Private torch API; best-effort. Newer torch versions take an + `input_device` argument and cache per-device, so pass the current CUDA + device to ensure the cache key matches later compile calls. + """ + try: + import inspect + + from torch._inductor.fx_passes.joint_graph import ( + lazy_init as _inductor_lazy_init, + ) + + if inspect.signature(_inductor_lazy_init).parameters: + _inductor_lazy_init(device) + else: + _inductor_lazy_init() + except Exception as e: # noqa: BLE001 + logger.info("Skipping inductor lazy_init pre-trigger: %s", e) + + class EagerAdaptor(CompilerInterface): name = "eager" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py index 24e156561dfb..87a9111018ec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py @@ -139,6 +139,10 @@ def inject_kv_into_layer( [num_tokens]. """ dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + # `slot_mapping` is built CPU-side in `ReqMeta.make_meta`; upload + # non-blocking so the advanced-index ops below don't force a + # synchronous H2D of the index tensor. 
+ slot_mapping = slot_mapping.to(dst_kv_cache_layer.device, non_blocking=True) if isinstance(attn_metadata, MLACommonMetadata): num_pages = dst_kv_cache_layer_shape[0] page_size = dst_kv_cache_layer_shape[1] @@ -188,7 +192,8 @@ def inject_kv_into_layer( filename = self._generate_filename_debug( layer_name, request.token_ids, request.mm_hashes ) - kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda() + kv_cache_cpu = safetensors.torch.load_file(filename)["kv_cache"] + kv_cache = kv_cache_cpu.to("cuda", non_blocking=True) if isinstance(attn_metadata, dict): inject_kv_into_layer( kv_cache_layer, @@ -235,6 +240,10 @@ def extract_kv_from_layer( Assume the shape of the layer is (2, num_pages, page_size, xxx) if MLA is not used, and (num_pages, page_size, xxx) otherwise. """ + # `slot_mapping` is built CPU-side in `ReqMeta.make_meta`; upload + # non-blocking so the advanced-index ops below don't force a + # synchronous H2D of the index tensor. + slot_mapping = slot_mapping.to(layer.device, non_blocking=True) if isinstance(attn_metadata, MLACommonMetadata): num_pages, page_size = layer.shape[0], layer.shape[1] return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] @@ -245,6 +254,8 @@ def extract_kv_from_layer( num_pages, page_size = layer.shape[1], layer.shape[2] return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, ExampleConnectorMetadata) for request in connector_metadata.requests: @@ -253,7 +264,9 @@ def extract_kv_from_layer( layer_name, request.token_ids, request.mm_hashes ) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) - tensors = {"kv_cache": kv_cache.detach().cpu()} + # `.cpu()` is an unavoidable D2H to serialize the cache. + with gpu_sync_allowed(): + tensors = {"kv_cache": kv_cache.detach().cpu()} safetensors.torch.save_file(tensors, filename) def wait_for_save(self): diff --git a/vllm/envs.py b/vllm/envs.py index d55732cb6a8d..8cfe3864a50c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -82,6 +82,7 @@ VLLM_MAIN_CUDA_VERSION: str = "13.0" VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest" VLLM_BATCH_INVARIANT: bool = False + VLLM_GPU_SYNC_CHECK: Literal["warn", "error"] | None = None MAX_JOBS: str | None = None NVCC_THREADS: str | None = None VLLM_USE_PRECOMPILED: bool = False @@ -523,6 +524,13 @@ def _get_or_set_default() -> str: # Enable batch-invariant mode: deterministic results regardless of # batch composition. Requires NVIDIA GPU with compute capability >= 9.0. "VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))), + # If set, enable PyTorch's GPU<->CPU synchronization debug mode around + # the worker's `execute_model` and `sample_tokens` calls. Valid values + # are "warn" (print a warning on each sync) or "error" (raise on sync). + # Unset disables the check. See `torch.cuda.set_sync_debug_mode`. + "VLLM_GPU_SYNC_CHECK": env_with_choices( + "VLLM_GPU_SYNC_CHECK", None, ["warn", "error"], case_sensitive=False + ), # Maximum number of compilation jobs to run in parallel. 
# By default this is the number of CPUs "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index af0a3157f02a..ec2faa767d2f 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.math_utils import next_power_of_2 +from vllm.utils.torch_utils import async_tensor_h2d logger = init_logger(__name__) is_batch_invariant = envs.VLLM_BATCH_INVARIANT @@ -49,7 +50,11 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): lora_strides_d1.append(lora_a_weight.stride(1)) lora_strides_d2.append(lora_a_weight.stride(2)) if len(lora_a_weights) > 1: - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) + # Pinned CPU + non_blocking H2D avoids the synchronous copy that + # `torch.tensor(list, device=cuda)` would otherwise force. + lora_ptr_tensor = async_tensor_h2d( + tensor_ptrs, dtype=torch.uint64, device=device + ) else: lora_ptr_tensor = lora_a_weights[0] @@ -106,10 +111,13 @@ def _get_lora_b_ptr( hidden_sizes.append(lora_b_weight.size(1)) if len(lora_weights) > 1: - # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) - slice_start_tensor = torch.tensor( - slice_offset_lst, device=device, dtype=torch.uint64 + # note these are device tensors. Pinned CPU + non_blocking H2D + # avoids the sync that `torch.tensor(list, device=cuda)` forces. + lora_ptr_tensor = async_tensor_h2d( + tensor_ptrs, dtype=torch.uint64, device=device + ) + slice_start_tensor = async_tensor_h2d( + slice_offset_lst, dtype=torch.uint64, device=device ) else: slice_start_tensor = slice_offset_lst[0] @@ -129,10 +137,19 @@ def _get_lora_b_ptr( same_stride = True else: - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) - hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device) + # Pinned CPU + non_blocking H2D to avoid blocking copies. 
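# Aside: the pinned-copy idiom used throughout this patch, as a self-contained
# sketch. It assumes `async_tensor_h2d` (vllm.utils.torch_utils) is essentially
# a staged copy through page-locked host memory; the real helper may differ.

import torch

def async_h2d_sketch(
    data: list[int], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    # Stage the Python list in pinned (page-locked) CPU memory; only pinned
    # memory lets the copy engine run the H2D transfer asynchronously.
    staged = torch.tensor(data, dtype=dtype, device="cpu", pin_memory=True)
    # Enqueue the copy without blocking the host, unlike
    # `torch.tensor(data, device="cuda")`, whose pageable-memory copy is
    # synchronous and trips `torch.cuda.set_sync_debug_mode("error")`.
    return staged.to(device, non_blocking=True)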
+ lora_strides_d0_tensor = async_tensor_h2d( + lora_strides_d0, dtype=torch.int64, device=device + ) + lora_strides_d1_tensor = async_tensor_h2d( + lora_strides_d1, dtype=torch.int64, device=device + ) + lora_strides_d2_tensor = async_tensor_h2d( + lora_strides_d2, dtype=torch.int64, device=device + ) + hidden_sizes_tensor = async_tensor_h2d( + hidden_sizes, dtype=torch.int64, device=device + ) same_stride = False # MAX_N is the maximum hidden size among all the lora_b weights MAX_N = max(hidden_sizes) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 44d1dbd50728..ce477d026bd6 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -14,6 +14,7 @@ from vllm.lora.layers import LoRAMapping from vllm.lora.utils import get_captured_lora_counts from vllm.triton_utils import HAS_TRITON, triton +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.math_utils import round_up if HAS_TRITON: @@ -83,9 +84,18 @@ def update_metadata( self.is_prefill = mapping.is_prefill self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size) - # Prepare cuda kernel metadata tensors - self.token_mapping_meta.prepare_tensors(self.token_lora_indices) - self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) + # This method has two unavoidable GPU->CPU syncs given the current + # design: (1) the `torch.all(... == -1)` no-lora check below, and + # (2) `torch.unique(...)` + reading `lora_ids.size(0)` as a Python + # int further down. Both ultimately stem from needing facts about + # `token_lora_mapping`'s contents on the host (is everything -1? + # how many distinct loras?). TODO: compute these on CPU upstream + # in `convert_mapping` where the mapping is still a Python list, + # then pass the results in. 
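# Aside: an assumed sketch of `gpu_sync_allowed` (vllm.utils.gpu_sync_debug),
# inferred from its use here and from the VLLM_GPU_SYNC_CHECK description in
# vllm/envs.py; the actual helper may be implemented differently.

import contextlib

import torch

@contextlib.contextmanager
def gpu_sync_allowed_sketch():
    # Temporarily return to the default (silent) mode so a known, unavoidable
    # GPU<->CPU sync inside the block neither warns nor raises...
    previous_mode = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("default")
    try:
        yield
    finally:
        # ...then restore whatever VLLM_GPU_SYNC_CHECK selected ("warn"/"error").
        torch.cuda.set_sync_debug_mode(previous_mode)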
+ with gpu_sync_allowed(): + # Prepare cuda kernel metadata tensors + self.token_mapping_meta.prepare_tensors(self.token_lora_indices) + self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) def add_shrink( self, diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index 584745f86b1a..8cf5f1a176ef 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -5,6 +5,8 @@ import torch +from vllm.utils.torch_utils import async_tensor_h2d + if TYPE_CHECKING: # avoid circuit import from vllm.lora.layers import LoRAMapping @@ -110,8 +112,8 @@ def convert_mapping( embedding_indices, ] - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor( + indices = async_tensor_h2d(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = async_tensor_h2d( prompt_mapping, dtype=torch.long, device=device ) embeddings_indices = torch.stack( diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 5b9bf2d76fbb..635b33762298 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -18,6 +18,7 @@ from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path +from vllm.utils.gpu_sync_debug import gpu_sync_allowed logger = init_logger(__name__) @@ -208,13 +209,21 @@ def _apply_adapters(self, adapter_requests: set[Any]) -> None: def add_adapter(self, adapter_request: Any) -> bool: if adapter_request.adapter_id in self.list_adapters(): return False - loaded_adapter = self._load_adapter(adapter_request) - loaded = self._adapter_manager.add_adapter(loaded_adapter) - self._adapter_manager.activate_adapter(loaded_adapter.id) + # One-time per adapter: load may sync (tensorizer) and + # `activate_adapter` eventually calls `set_lora` / `reset_lora` + # which write per-adapter scalars to GPU buffers. + with gpu_sync_allowed(): + loaded_adapter = self._load_adapter(adapter_request) + loaded = self._adapter_manager.add_adapter(loaded_adapter) + self._adapter_manager.activate_adapter(loaded_adapter.id) return loaded def remove_adapter(self, adapter_id: int) -> bool: - return self._adapter_manager.remove_adapter(adapter_id) + # Adapter removal calls `reset_lora` which does small per-adapter + # scalar-index writes on GPU buffers (e.g. `adapter_enabled[i] = 0`); + # one-time per adapter. + with gpu_sync_allowed(): + return self._adapter_manager.remove_adapter(adapter_id) def remove_all_adapters(self): self._adapter_manager.remove_all_adapters() @@ -279,24 +288,32 @@ def add_adapter(self, lora_request: LoRARequest) -> bool: # evicting any existing adapters. # This may cause the # of loaded lora adapters to very temporarily # exceed `--max-cpu-loras`. - lora = self._load_adapter(lora_request) - - # Remove the existing adapter if it exists - # Use case for LoRA inplace - self._adapter_manager.remove_adapter(lora.id) - - # Loading succeeded, now check if we will exceed cache capacity and - # evict if the oldest adapter if so - if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: - assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager) - self._adapter_manager.remove_oldest_adapter() - # Then add the new adapter to the cache - loaded = self._adapter_manager.add_adapter(lora) + # Adapter loading may sync (e.g. tensorizer's H2D weight copy) and + # the subsequent `reset_lora` / `set_lora` bookkeeping does small + # per-adapter scalar writes to GPU buffers. 
These are one-time + # per adapter lifecycle event, so allow syncs for the whole block. + with gpu_sync_allowed(): + lora = self._load_adapter(lora_request) + + # Remove the existing adapter if it exists + # Use case for LoRA inplace + self._adapter_manager.remove_adapter(lora.id) + + # Loading succeeded, now check if we will exceed cache capacity + # and evict if the oldest adapter if so + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + # Then add the new adapter to the cache + loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches loaded = ( self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None ) - self._adapter_manager.activate_adapter(lora_request.lora_int_id) + # `activate_adapter` eventually calls `set_lora` / `reset_lora` which + # write per-adapter scalars to GPU buffers; allow the one-time sync. + with gpu_sync_allowed(): + self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 0e476755201e..5d738548b046 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -361,6 +361,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): initial_state_idx=block_idx_last_computed_token_p, num_computed_tokens=num_computed_tokens_p, block_size_to_align=mamba_block_size, + metadata=attn_metadata, ) # 3. State Space Model sequence transformations. discrete_time_step_p, B_p, C_p = self._ssm_transform( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2b4b1934f9b3..f05174ae4b03 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -761,17 +761,27 @@ def conv_ssm_forward( # then chunk_stride = 2 chunk_stride = mamba_block_size // chunk_size + # The per-sequence loop below uses these as Python scalars + # (slice bounds, `if == 0` tests, `%`, etc). Pull them to + # CPU once so each iteration doesn't force a D2H sync. + # These are small (per-request) and are known-unavoidable + # D2H reads; allow the one-shot sync instead of paying a + # per-iteration D2H cost. 
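# Aside: the host-read pattern the loop below relies on, reduced to a
# self-contained sketch (names are placeholders, not vLLM APIs). One
# `.tolist()` up front costs a single D2H transfer; indexing GPU scalars
# inside the loop would cost one synchronizing read per iteration instead.

import torch

def first_chunks(last_chunk_indices: torch.Tensor, num_prefills: int) -> list[int]:
    last_chunk_indices_cpu = last_chunk_indices.tolist()  # one D2H for the batch
    chunks = []
    for seq_idx in range(num_prefills):
        # The per-iteration alternative, `int(last_chunk_indices[seq_idx - 1])`,
        # would synchronize on every pass.
        chunks.append(0 if seq_idx == 0 else 1 + last_chunk_indices_cpu[seq_idx - 1])
    return chunks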
+ from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + block_idx_first_cpu = block_idx_first_scheduled_token_p.tolist() + block_idx_last_cpu = block_idx_last_scheduled_token_p.tolist() + num_computed_tokens_p_cpu = num_computed_tokens_p.tolist() + last_chunk_indices_p_cpu = last_chunk_indices_p.tolist() + # Save state for sequences with more than just final state for seq_idx in range(num_prefills): # Block index for the first scheduled token - block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[ - seq_idx - ] + block_idx_first_scheduled_token = block_idx_first_cpu[seq_idx] # Block index for the last scheduled token - block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[ - seq_idx - ] + block_idx_last_scheduled_token = block_idx_last_cpu[seq_idx] # Number of blocks that need to be written n_blocks_to_fill = ( @@ -792,7 +802,7 @@ def conv_ssm_forward( if seq_idx == 0: first_chunk = 0 else: - first_chunk = 1 + last_chunk_indices_p[seq_idx - 1] + first_chunk = 1 + last_chunk_indices_p_cpu[seq_idx - 1] # First chunk that is aligned on the mamba block boundary first_aligned_chunk = first_chunk + chunk_stride - 1 @@ -800,7 +810,7 @@ def conv_ssm_forward( # Calculate the number of computed tokens that were not # already cached num_unaligned_computed_tokens = ( - num_computed_tokens_p[seq_idx] % mamba_block_size + num_computed_tokens_p_cpu[seq_idx] % mamba_block_size ) if num_unaligned_computed_tokens > 0: diff --git a/vllm/model_executor/layers/pooler/seqwise/methods.py b/vllm/model_executor/layers/pooler/seqwise/methods.py index b967ff4ede7b..82170b5fbdc4 100644 --- a/vllm/model_executor/layers/pooler/seqwise/methods.py +++ b/vllm/model_executor/layers/pooler/seqwise/methods.py @@ -68,21 +68,23 @@ def forward( "partial prefill not supported with MEAN pooling" ) - prompt_lens = pooling_cursor.prompt_lens_cpu.to( - hidden_states.device, dtype=torch.int64, non_blocking=True - ) - - num_seqs = prompt_lens.numel() + prompt_lens_cpu = pooling_cursor.prompt_lens_cpu + num_seqs = prompt_lens_cpu.numel() hidden_size = hidden_states.shape[-1] if num_seqs == 0: # early return for empty batch return hidden_states.new_empty((0, hidden_size), dtype=torch.float32) - # eg. [2, 1, 3] -> [0, 0, 1, 2, 2, 2] + # Build segment_ids on CPU so repeat_interleave doesn't need to sync + # GPU->CPU to learn its data-dependent output length, then upload + # non-blocking. eg. [2, 1, 3] -> [0, 0, 1, 2, 2, 2] segment_ids = torch.repeat_interleave( - torch.arange(num_seqs, device=hidden_states.device, dtype=torch.long), - prompt_lens, + torch.arange(num_seqs, dtype=torch.long), + prompt_lens_cpu, + ).to(hidden_states.device, non_blocking=True) + prompt_lens = prompt_lens_cpu.to( + hidden_states.device, dtype=torch.int64, non_blocking=True ) segment_sums = torch.zeros( (num_seqs, hidden_size), diff --git a/vllm/model_executor/layers/pooler/tokwise/methods.py b/vllm/model_executor/layers/pooler/tokwise/methods.py index d3fefb745cfe..1a6579e75ae9 100644 --- a/vllm/model_executor/layers/pooler/tokwise/methods.py +++ b/vllm/model_executor/layers/pooler/tokwise/methods.py @@ -47,17 +47,12 @@ def forward( pooling_metadata: PoolingMetadata, ) -> list[TokenPoolingMethodOutputItem]: pooling_cursor = pooling_metadata.get_pooling_cursor() - split_sizes = pooling_cursor.num_scheduled_tokens_cpu.tolist() - if split_sizes: - # DispatchPooler passes the full hidden_states tensor. 
- # slice out the subgroup once, then split it by - # per-request token counts - group_start = int(pooling_cursor.first_token_indices_gpu[0].item()) - group_end = int(pooling_cursor.last_token_indices_gpu[-1].item()) + 1 - hidden_states_group = hidden_states[group_start:group_end] - hidden_states_lst = list(hidden_states_group.split(split_sizes)) - else: - hidden_states_lst = [] + # Use the already-CPU num_scheduled_tokens tensor so `.tolist()` + # doesn't trigger a GPU->CPU sync. torch.split produces the same + # consecutive slices as indexing with first/last per-sequence indices. + hidden_states_lst = list( + torch.split(hidden_states, pooling_cursor.num_scheduled_tokens_cpu.tolist()) + ) if not self.enable_chunked_prefill: return hidden_states_lst @@ -95,12 +90,15 @@ def forward( pooling_metadata: PoolingMetadata, ) -> list[TokenPoolingMethodOutputItem]: pooled_data_lst = super().forward(hidden_states, pooling_metadata) - prompt_token_ids = pooling_metadata.get_prompt_token_ids() + # Use the CPU copy of prompt_token_ids so the step_tag_id mask can be + # resolved to indices without a device->host sync from boolean + # indexing. + prompt_token_ids_cpu = pooling_metadata.get_prompt_token_ids_cpu() pooling_params = pooling_metadata.pooling_params pooled_data = list[torch.Tensor | None]() - for data, token_id, pooling_param in zip( - pooled_data_lst, prompt_token_ids, pooling_params + for data, token_id_cpu, pooling_param in zip( + pooled_data_lst, prompt_token_ids_cpu, pooling_params ): # for unfinished chunked prefill if data is None: @@ -113,7 +111,9 @@ def forward( data = data[:, returned_token_ids] if step_tag_id is not None: - data = data[token_id == step_tag_id] + idx_cpu = (token_id_cpu == step_tag_id).nonzero(as_tuple=True)[0] + idx = idx_cpu.to(data.device, non_blocking=True) + data = data[idx] pooled_data.append(data) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 729924663646..8202f0deac48 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -291,6 +291,11 @@ def _apply_8bit_weight( # only load the bitsandbytes module when needed from bitsandbytes import MatmulLtState, matmul + # BnB's `int8_vectorwise_quant` branches on `outliers.any()` which + # forces a D2H sync; this is third-party code we can't refactor, + # and it's a per-layer fixed cost in int8 quant mode. 
+ from vllm.utils.gpu_sync_debug import gpu_sync_allowed + original_type = x.dtype original_shape = x.shape reshape_after_matmul = False @@ -334,9 +339,10 @@ def _apply_8bit_weight( new_x = bf_x.unsqueeze(0) - out[:, current_index : current_index + output_size] = matmul( - new_x, qweight[offsets[i] : offsets[i + 1]], state=matmul_states[i] - ) + with gpu_sync_allowed(): + out[:, current_index : current_index + output_size] = matmul( + new_x, qweight[offsets[i] : offsets[i + 1]], state=matmul_states[i] + ) current_index += output_size diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 01854b96d56f..f34e1c97ba9f 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -559,13 +559,14 @@ def _encode_token_type_ids( def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: - ids_mask = ( - torch.ones_like(input_ids, dtype=torch.int32, device=input_ids.device) - << TOKEN_TYPE_SHIFT - ) - tokens_mask = ids_mask.bitwise_not() - - token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT + # Use scalar masks rather than a `torch.ones_like(...)` same-shape tensor: + # the uniform-value tensor would otherwise trip inductor's + # constant_fold_uniform_value pass (which calls `.item()` at compile time) + # and it also wastes a per-step GPU allocation. + ids_mask = 1 << TOKEN_TYPE_SHIFT + tokens_mask = ~ids_mask + + token_type_ids = (input_ids & ids_mask) >> TOKEN_TYPE_SHIFT input_ids.bitwise_and_(tokens_mask) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index a150428baff4..2fcbc819ea10 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -58,6 +58,7 @@ ) from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import async_tensor_h2d from .interfaces import ( MultiModalEmbeddings, @@ -821,8 +822,20 @@ def img2bpe_mapping_tensor(self): def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: device = img_batch.device - img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] - return img_tokens.to(device) + # Cache a per-device copy of the (small, static) mapping tensor so we + # can index entirely on `device` instead of forcing a D2H on + # `img_batch` and an H2D on the result. + cache = getattr(self, "_img2bpe_mapping_cache", None) + if cache is None: + cache = {} + object.__setattr__(self, "_img2bpe_mapping_cache", cache) + mapping_on_device = cache.get(device) + if mapping_on_device is None: + mapping_on_device = self.img2bpe_mapping_tensor.pin_memory().to( + device, non_blocking=True + ) + cache[device] = mapping_on_device + return mapping_on_device[img_batch] class ChameleonModel(nn.Module): @@ -1017,8 +1030,23 @@ def compute_logits( # Disallow image tokens which does not include special # begin-image and end-image tokens if logits is not None: - image_tokens = self.model.vocabulary_mapping.image_tokens - logits[:, image_tokens] = torch.finfo(logits.dtype).min + # Cache a per-device index tensor for the (static) image-token + # set, and use `index_fill_` instead of advanced-index assign + # so the scatter runs entirely on device (no host roundtrip + # for the scalar fill value or for the Python-list indices). 
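# Aside: the cached-index + `index_fill_` pattern used below, in isolation
# (illustrative names only; the real code keys its cache on `logits.device`).

import torch

_index_cache: dict[torch.device, torch.Tensor] = {}

def mask_columns(logits: torch.Tensor, columns: list[int]) -> torch.Tensor:
    idx = _index_cache.get(logits.device)
    if idx is None:
        # Built once per device; later calls reuse the device-resident index,
        # so no per-step H2D of the Python list is needed.
        idx = torch.tensor(
            columns, dtype=torch.long, device="cpu", pin_memory=True
        ).to(logits.device, non_blocking=True)
        _index_cache[logits.device] = idx
    # `index_fill_` takes the fill value as a Python scalar, so no scalar
    # tensor has to be materialized and copied to the device.
    return logits.index_fill_(1, idx, torch.finfo(logits.dtype).min)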
+ cache = getattr(self, "_image_tokens_index_cache", None) + if cache is None: + cache = {} + self._image_tokens_index_cache = cache + image_tokens_idx = cache.get(logits.device) + if image_tokens_idx is None: + image_tokens_idx = async_tensor_h2d( + self.model.vocabulary_mapping.image_tokens, + dtype=torch.long, + device=logits.device, + ) + cache[logits.device] = image_tokens_idx + logits.index_fill_(1, image_tokens_idx, torch.finfo(logits.dtype).min) return logits diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 0f059b6d1340..e6814916eb58 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -305,7 +305,9 @@ def _get_mm_fields_config( return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), - num_patches=MultiModalFieldConfig.batched("image"), + # `num_patches` is only consumed as Python split sizes via + # `.tolist()` in `_process_image_input`; keep it CPU-resident. + num_patches=MultiModalFieldConfig.batched("image", keep_on_cpu=True), ) def _get_prompt_updates( diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 4e9838805e58..f572168e5c37 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -54,6 +54,7 @@ ) from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import async_tensor_h2d from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription from .utils import ( @@ -635,10 +636,18 @@ def _process_audio_input( # We handle both cases: # - If fewer tokens: pad with the embedding of the last vocab token # - If more tokens: truncate to the expected count - # TODO precompute and cache padding - audio_padding_toks = torch.tensor( - [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device - ) + # Cache the single-scalar padding-token tensor per-device to avoid a + # synchronous H2D tensor construction on every forward. + cache = getattr(self, "_audio_padding_toks_cache", None) + if cache is None: + cache = {} + self._audio_padding_toks_cache = cache + audio_padding_toks = cache.get(audio_features.device) + if audio_padding_toks is None: + audio_padding_toks = async_tensor_h2d( + [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device + ) + cache[audio_features.device] = audio_padding_toks audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) audio_features = torch.where( audio_mask.unsqueeze(-1), audio_padding_embs, audio_features diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 85f422342a95..52eab15afcd6 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -723,6 +723,10 @@ def rot_pos_emb( # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + # pos_ids was built on CPU above; upload non-blocking so the two + # gathers below don't force a synchronous H2D for the indices. 
+ pos_ids = pos_ids.to(cos.device, non_blocking=True) + cos_combined = cos[pos_ids].flatten(1) sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined, pos_ids diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index 036b92ed8808..1b8de1cc4b20 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -143,7 +143,8 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict( input_features=MultiModalFieldConfig.batched("audio"), - audio_embed_sizes=MultiModalFieldConfig.batched("audio"), + # Only consumed as Python split sizes / scalars; keep CPU-resident. + audio_embed_sizes=MultiModalFieldConfig.batched("audio", keep_on_cpu=True), ) def _get_prompt_updates( @@ -717,13 +718,17 @@ def _build_input_features_mask( torch.Tensor: Mask of shape (bsz, num_features) to be applied to the audio features prior to splitting the audio embeddings. """ - most_audio_features = torch.max(audio_embed_sizes).item() - mask_indices = torch.arange( - most_audio_features, - device=audio_embed_sizes.device, - ).view(1, -1) + # `audio_embed_sizes` is CPU-resident (see `_get_mm_fields_config`). + # Build the mask on CPU and pin+non-blocking upload to the encoder's + # device — the boolean mask is later used to index GPU + # `projected_embeds` in `_process_audio_input`. + most_audio_features = int(torch.max(audio_embed_sizes)) + mask_indices = torch.arange(most_audio_features).view(1, -1) input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1) - return input_features_mask + target_device = self.encoder.input_linear.weight.device + if target_device == input_features_mask.device: + return input_features_mask + return input_features_mask.pin_memory().to(target_device, non_blocking=True) def _pad_and_stack_input_features( self, @@ -777,8 +782,13 @@ def _process_audio_input( encoder_embeds = self.encoder(audio_input["input_features"]) # [bsz, , 4096] projected_embeds = self.projector(encoder_embeds) - # Apply mask on variable length audio features - masked_embeds = projected_embeds[audio_input["input_features_mask"]] + # Apply mask on variable length audio features. Boolean-mask indexing + # has a data-dependent output shape and always syncs on CUDA; this + # runs once per MM encoder call. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + masked_embeds = projected_embeds[audio_input["input_features_mask"]] # Split variable length features into a tuple return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 7db2e823fbc6..3cad59c074ee 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -39,6 +39,7 @@ ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model @@ -95,10 +96,18 @@ def forward( size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 ) - for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - if tgt_sizes is not None: - nb_patches_h = tgt_sizes[batch_idx][0] - nb_patches_w = tgt_sizes[batch_idx][1] + # This loop runs CPU-side (position_ids/boundaries are on CPU). 
+ # Bring the per-image inputs to CPU once up front so the loop doesn't + # pay a per-iteration D2H sync for `nb_patches_*` (via `.sum()` on a + # GPU `p_attn_mask`) or for the boolean indexing into `position_ids`. + with gpu_sync_allowed(): + patch_attention_mask_cpu = patch_attention_mask.cpu() + tgt_sizes_cpu = tgt_sizes.cpu() if tgt_sizes is not None else None + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask_cpu): + if tgt_sizes_cpu is not None: + nb_patches_h = tgt_sizes_cpu[batch_idx][0] + nb_patches_w = tgt_sizes_cpu[batch_idx][1] else: nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() @@ -113,8 +122,12 @@ def forward( pos_ids = ( bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w ).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids - position_ids = position_ids.to(self.position_embedding.weight.device) + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids + # `position_ids` is a CPU tensor built above; pin+non_blocking upload + # to avoid a synchronous H2D copy. + position_ids = position_ids.to( + self.position_embedding.weight.device, non_blocking=True + ) embeddings += self.position_embedding(position_ids) return embeddings diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index d3ffdd4cf29a..6d9b01aa107a 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -51,6 +51,7 @@ PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import ( @@ -355,7 +356,9 @@ def _get_mm_fields_config( "image", num_patches ), image_embeds=MultiModalFieldConfig.batched("image"), - num_patches=MultiModalFieldConfig.batched("image"), + # Only consumed as Python split sizes in `_process_image_input`; + # keep CPU-resident to avoid the `.tolist()` D2H sync. + num_patches=MultiModalFieldConfig.batched("image", keep_on_cpu=True), ) def _get_prompt_updates( @@ -499,11 +502,11 @@ def image_pixels_to_features( real_images_inds = (pixel_values == 0.0).sum( dim=(-1, -2, -3) ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + with gpu_sync_allowed(): + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() patch_size = self.config.vision_config.patch_size patches_subgrid = pixel_attention_mask.unfold( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index c8611a499362..262959685790 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -247,9 +247,16 @@ def _get_image_fields_config(self, hf_inputs: BatchFeature): pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( "image", image_num_patches ), - image_num_patches=MultiModalFieldConfig.batched("image"), + # Only consumed as Python split sizes in `_process_vision_input`; + # keep CPU-resident to avoid D2H syncs. 
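# Aside: why `keep_on_cpu=True` matters for fields that are only ever read as
# Python ints. A tiny illustration with placeholder names; it needs a CUDA
# device to show the difference.

import torch

hidden = torch.randn(10, 8, device="cuda")
num_patches_cpu = torch.tensor([3, 5, 2])           # stays host-resident
per_image = hidden.split(num_patches_cpu.tolist())  # .tolist() is free on CPU;
# the same call on a CUDA `num_patches` tensor would stall the host just to
# fetch the split sizes.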
+ image_num_patches=MultiModalFieldConfig.batched("image", keep_on_cpu=True), image_embeds=MultiModalFieldConfig.batched("image"), - image_token_id=MultiModalFieldConfig.shared("image", num_images), + # Scalar metadata consumed only as a Python int (via .item()) by + # `_parse_and_validate_image_input`; keep on CPU to avoid a + # pointless H2D and the downstream `.item()` sync on GPU. + image_token_id=MultiModalFieldConfig.shared( + "image", num_images, keep_on_cpu=True + ), ) def _get_mm_fields_config( @@ -475,7 +482,8 @@ def _get_video_fields_config(self, hf_inputs: BatchFeature): pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( "video", video_num_patches ), - video_num_patches=MultiModalFieldConfig.batched("video"), + # Only consumed as Python split sizes; keep CPU-resident. + video_num_patches=MultiModalFieldConfig.batched("video", keep_on_cpu=True), video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 739c90a4292b..5385de4d8a06 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -214,7 +214,9 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict( pixel_values=MultiModalFieldConfig.batched("image"), - image_sizes=MultiModalFieldConfig.batched("image"), + # Only consumed as Python ints via `.tolist()` for grid-shape math; + # keep CPU-resident to avoid D2H syncs. + image_sizes=MultiModalFieldConfig.batched("image", keep_on_cpu=True), image_embeds=MultiModalFieldConfig.batched("image"), ) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 638d9ba9d892..9534bf3dd3e9 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -316,7 +316,9 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict( pixel_values=MultiModalFieldConfig.batched("image"), - image_sizes=MultiModalFieldConfig.batched("image"), + # Only consumed as Python ints via `.tolist()` for grid-shape math; + # keep CPU-resident to avoid D2H syncs. 
+ image_sizes=MultiModalFieldConfig.batched("image", keep_on_cpu=True), image_embeds=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.batched("video"), ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 95689ef321b8..197211cc14c1 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -56,6 +56,7 @@ ResolvedPromptUpdate, ) from vllm.sequence import IntermediateTensors +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel @@ -266,9 +267,11 @@ def hd_feature_transform(self, image_features, image_sizes): ) batch_image_features_proj = [] + with gpu_sync_allowed(): + image_sizes_list = image_sizes.tolist() # need a for loop to process each image because of different image sizes # (patch arrangement is different for each image) - for i, img_size in enumerate(image_sizes): + for i, img_size in enumerate(image_sizes_list): h, w = img_size h_crop = h // 336 w_crop = w // 336 diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index c3b09ed590dd..15190d8f6748 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -586,7 +586,11 @@ def forward_embeddings( seq_len, batch_size, self.chunk_size, self.left_chunk ) device = xs_pad.device - enc_streaming_mask = enc_streaming_mask.to(device) + # `enc_streaming_mask` is a CPU tensor; pin + non_blocking upload to + # avoid a synchronous H2D copy. + enc_streaming_mask = enc_streaming_mask.contiguous().to( + device, non_blocking=True + ) xs_pad = xs_pad.to(device) input_tensor = xs_pad @@ -605,7 +609,9 @@ def forward_embeddings( seq_len, batch_size, chunk_size_nc, left_chunk_nc ) if device.type != "cpu": - enc_streaming_mask_nc = enc_streaming_mask_nc.to(device) + enc_streaming_mask_nc = enc_streaming_mask_nc.contiguous().to( + device, non_blocking=True + ) if masks is not None: hs_mask_nc = masks & enc_streaming_mask_nc else: @@ -917,7 +923,11 @@ def calculate_hs_mask( enc_streaming_mask = self._streaming_mask( max_audio_length, batch_size, self.chunk_size, self.left_chunk ) - enc_streaming_mask = enc_streaming_mask.to(device) + # `enc_streaming_mask` is a CPU tensor; pin + non_blocking upload to + # avoid a synchronous H2D copy. 
+ enc_streaming_mask = enc_streaming_mask.contiguous().to( + device, non_blocking=True + ) if mask is None: return enc_streaming_mask diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 14f4c424bb31..730fe478f5a4 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -278,15 +278,15 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): image_embeds=MultiModalFieldConfig.flat_from_sizes( "image", image_embed_grid_sizes ), - image_grid_thw=MultiModalFieldConfig.batched("image"), + image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes ), video_embeds=MultiModalFieldConfig.flat_from_sizes( "video", video_embed_grid_sizes ), - video_grid_thw=MultiModalFieldConfig.batched("video"), - second_per_grid_ts=MultiModalFieldConfig.batched("video"), + video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), + second_per_grid_ts=MultiModalFieldConfig.batched("video", keep_on_cpu=True), use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos), ) @@ -960,12 +960,20 @@ def _process_audio_input( self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths) ) - audio_outputs = self.audio_tower( - input_features.to(self.audio_tower.dtype), - feature_lens=audio_feature_lengths, - aftercnn_lens=audio_feat_lengths, - ) - return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist()) + # Upstream transformers `qwen2_5_omni` audio tower does + # `torch.full((chunk_num.sum(),), ...)` which syncs on a GPU scalar + # reduction. Third-party; suppress at the integration boundary. The + # subsequent `.tolist()` on GPU output lengths is also unavoidable. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + split_sizes = audio_output_lengths.tolist() + return audio_outputs.last_hidden_state.split(split_sizes) def _process_image_input( self, image_input: Qwen2_5_VLImageInputs @@ -1450,12 +1458,19 @@ def embed_input_ids( video_token_id = self.config.video_token_index audio_token_id = self.config.audio_token_index - input_ids_cpu = input_ids.cpu() - is_video = is_multimodal & (input_ids_cpu == video_token_id) - is_audio = is_multimodal & (input_ids_cpu == audio_token_id) + # Branch on a Python scalar below; the `input_ids.cpu()` and + # `.item()` reductions are unavoidable without refactoring the + # interleave-merge to be fully GPU-resident. Run under an + # allowed-sync block since this happens once per MM embed call. 
+ from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + input_ids_cpu = input_ids.cpu() + is_video = is_multimodal & (input_ids_cpu == video_token_id) + is_audio = is_multimodal & (input_ids_cpu == audio_token_id) - num_video = is_video.sum().item() - num_audio = is_audio.sum().item() + num_video = is_video.sum().item() + num_audio = is_audio.sum().item() if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio): inputs_embeds = self._embed_text_input_ids( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c11684b4b89b..7ee8adc53a4a 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -84,8 +84,10 @@ from vllm.sequence import IntermediateTensors from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.attention.backends.registry import AttentionBackendEnum +from ...utils.gpu_sync_debug import gpu_sync_allowed from .interfaces import ( MultiModalEmbeddings, SupportsEagle, @@ -677,6 +679,10 @@ def rotary_pos_emb_thw(self, t, h, w): # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_size) + # pos_ids was built on CPU above; upload non-blocking so the two + # gathers below don't force a synchronous H2D for the indices. + pos_ids = pos_ids.to(cos.device, non_blocking=True) + cos_combined = cos[pos_ids].flatten(1) sin_combined = sin[pos_ids].flatten(1) @@ -735,9 +741,12 @@ def get_rope_by_thw(self, t, h, w): window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w) cos_thw, sin_thw = self.rotary_pos_emb_thw(t, h, w) - cos_thw = cos_thw[window_index_thw, :, :] + # window_index_thw is built on CPU; upload non-blocking so the gathers + # below don't force a synchronous H2D for the index tensor. + window_index_thw_dev = window_index_thw.to(cos_thw.device, non_blocking=True) + cos_thw = cos_thw[window_index_thw_dev, :, :] cos_thw = cos_thw.flatten(start_dim=0, end_dim=1) - sin_thw = sin_thw[window_index_thw, :, :] + sin_thw = sin_thw[window_index_thw_dev, :, :] sin_thw = sin_thw.flatten(start_dim=0, end_dim=1) cu_seqlens_thw = torch.repeat_interleave( @@ -926,7 +935,9 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict( **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), - second_per_grid_ts=MultiModalFieldConfig.batched("video"), + # Only consumed as Python scalars via `.item()` in the EVS path; + # keep CPU-resident. + second_per_grid_ts=MultiModalFieldConfig.batched("video", keep_on_cpu=True), ) def _call_hf_processor( @@ -1249,7 +1260,11 @@ def _postprocess_image_embeds_evs( grid_thw_list = grid_thw.tolist() image_embeds_out = [] for emb, size in zip(image_embeds_split, grid_thw_list): - positions = compute_mrope_for_media(size, merge_size).to(emb.device) + # `compute_mrope_for_media` returns a CPU tensor; pin + + # non-blocking upload to avoid synchronous H2D. 
+ positions = compute_mrope_for_media(size, merge_size).to( + emb.device, non_blocking=True + ) emb = torch.cat([emb, positions], dim=1) image_embeds_out.append(emb) image_embeds_split = image_embeds_out @@ -1331,10 +1346,15 @@ def _postprocess_video_embeds_evs( merge_size, tokens_per_second=tokens_per_second, video_second_per_grid=video_second_per_grid_t.item(), - ).to(emb.device) + ).to(emb.device, non_blocking=True) + + # Boolean-mask indexing has a data-dependent output shape and + # always syncs on CUDA; runs once per video in the EVS path. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed - emb = emb[retention_mask] - positions = positions[retention_mask] + with gpu_sync_allowed(): + emb = emb[retention_mask] + positions = positions[retention_mask] emb = torch.cat([emb, positions], dim=1) video_embeds_out.append(emb) return tuple(video_embeds_out) @@ -1376,23 +1396,24 @@ def recompute_mrope_positions( else mrope_positions.device ) - # Tensors - input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + # Tensors. + input_ids_t = async_tensor_h2d(input_ids, dtype=torch.long, device=device) mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] mm_embeddings_pos = [ mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings ] - positions, mrope_positions_delta = recompute_mrope_positions( - input_ids_t, - mm_embeddings_pos, - mrope_positions, - num_computed_tokens, - vision_start_token_id, - image_token_id, - video_token_id, - ) + with gpu_sync_allowed(): + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) return tuple(mm_embeddings_out), positions, mrope_positions_delta diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index e7e8d74714cd..13086fc2395c 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -60,6 +60,7 @@ PromptUpdate, ) from vllm.sequence import IntermediateTensors +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -439,17 +440,24 @@ def _process_audio_input( num_audios, max_audio_tokens, embed_dim = audio_features.shape audio_output_lengths = audio_output_lengths.unsqueeze(1) audio_features_mask = ( - torch.arange(max_audio_tokens) - .expand(num_audios, max_audio_tokens) - .to(audio_output_lengths.device) + torch.arange(max_audio_tokens, device=audio_output_lengths.device).expand( + num_audios, max_audio_tokens + ) < audio_output_lengths ) - masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) + # The boolean-mask gather below and the `.tolist()` that feeds + # `torch.split` both force GPU->CPU syncs (output sizes are + # data-dependent). Suppress the sync check here; restructuring the + # downstream contract to avoid these syncs would be a broader refactor. + with gpu_sync_allowed(): + masked_audio_features = audio_features[audio_features_mask].view( + -1, embed_dim + ) - # Split to tuple of embeddings for individual audio input. - return torch.split( - masked_audio_features, audio_output_lengths.flatten().tolist() - ) + # Split to tuple of embeddings for individual audio input. 
+ return torch.split( + masked_audio_features, audio_output_lengths.flatten().tolist() + ) def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 176f45781081..75017637123c 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -632,6 +632,10 @@ def rot_pos_emb( # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + # pos_ids was built on CPU above; upload non-blocking so the two + # gathers below don't force a synchronous H2D for the indices. + pos_ids = pos_ids.to(cos.device, non_blocking=True) + cos_combined = cos[pos_ids].flatten(1) sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined diff --git a/vllm/model_executor/models/qwen3_asr.py b/vllm/model_executor/models/qwen3_asr.py index 950beba77541..ad0d00fa207b 100644 --- a/vllm/model_executor/models/qwen3_asr.py +++ b/vllm/model_executor/models/qwen3_asr.py @@ -374,7 +374,13 @@ def _process_audio_input( feature_lens=audio_feature_lengths, aftercnn_lens=audio_output_lengths, ) - return audio_features.split(audio_output_lengths.tolist()) + # `.tolist()` on GPU output lengths forces a D2H sync; split sizes are + # data-dependent so this is unavoidable. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + split_sizes = audio_output_lengths.tolist() + return audio_features.split(split_sizes) def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 44e656ab820b..dc5e2ed23b04 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -80,6 +80,7 @@ ) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.attention.backends.registry import AttentionBackendEnum from .interfaces import ( @@ -427,11 +428,19 @@ def forward( feature_lens: torch.Tensor, aftercnn_lens: torch.Tensor, ): - # Compute chunk information + # Compute chunk information. `feature_lens` is small (per-audio) so + # pull to CPU once to compute total chunk count and split sizes + # without per-iteration D2H syncs. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() - chunk_lengths = torch.tensor( - [self.n_window * 2] * chunk_num.sum(), + with gpu_sync_allowed(): + total_chunks = int(chunk_num.sum()) + # `torch.tensor([...], device=cuda)` forces a blocking H2D; use + # pinned CPU + non-blocking upload instead. 
+ chunk_lengths = async_tensor_h2d( + [self.n_window * 2] * total_chunks, dtype=torch.long, device=feature_lens.device, ) @@ -440,15 +449,20 @@ def forward( chunk_lengths[chunk_lengths == 0] = self.n_window * 2 # Split input features into chunks and pad - chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + with gpu_sync_allowed(): + chunk_lengths_list = chunk_lengths.tolist() + chunk_list = input_features.T.split(chunk_lengths_list, dim=0) padded_feature = nn.utils.rnn.pad_sequence( chunk_list, batch_first=True ).transpose(1, 2) # Compute feature lengths after CNN feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths) - # Vectorized mask creation: avoid creating many small tensors - max_len_after_cnn = feature_lens_after_cnn.max().item() + # Vectorized mask creation: avoid creating many small tensors. The + # `.item()` below is unavoidable — `torch.arange` needs a Python int + # for the output size. Runs once per audio forward. + with gpu_sync_allowed(): + max_len_after_cnn = feature_lens_after_cnn.max().item() indices = torch.arange(max_len_after_cnn, device=padded_feature.device) padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze( 1 @@ -487,24 +501,28 @@ def forward( ) padded_embed = padded_embed + positional_embedding - # Extract valid hidden states and compute cu_seqlens - hidden_states = padded_embed[padded_mask_after_cnn] + # Boolean-mask indexing has a data-dependent output shape; syncs on + # CUDA. Runs once per audio forward. + with gpu_sync_allowed(): + hidden_states = padded_embed[padded_mask_after_cnn] + # `aftercnn_lens.tolist()` is an unavoidable D2H below. + aftercnn_lens_list = aftercnn_lens.tolist() # Compute cumulative sequence lengths for chunked attention cu_chunk_lens = [0] window_aftercnn = padded_mask_after_cnn.shape[-1] * ( self.n_window_infer // (self.n_window * 2) ) - # Use tolist() for efficient batch conversion from tensor to Python - for cnn_len in aftercnn_lens.tolist(): + for cnn_len in aftercnn_lens_list: num_full_chunks = cnn_len // window_aftercnn remainder = cnn_len % window_aftercnn cu_chunk_lens.extend([window_aftercnn] * num_full_chunks) if remainder: cu_chunk_lens.append(remainder) - cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum( - -1, dtype=torch.int32 - ) + # Build on pinned CPU then non-blocking upload. + cu_seqlens = async_tensor_h2d( + cu_chunk_lens, dtype=torch.int32, device=aftercnn_lens.device + ).cumsum(-1, dtype=torch.int32) max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) @@ -861,6 +879,10 @@ def rot_pos_emb(self, grid_thw): # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + # pos_ids was built on CPU above; upload non-blocking so the two + # gathers below don't force a synchronous H2D for the indices. 
+ pos_ids = pos_ids.to(cos.device, non_blocking=True) + cos_combined = cos[pos_ids].flatten(1) sin_combined = sin[pos_ids].flatten(1) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 2575f634be1c..48d7bb0d3123 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -101,6 +101,8 @@ from vllm.utils.math_utils import round_up from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers +from ...utils.gpu_sync_debug import gpu_sync_allowed +from ...utils.torch_utils import async_tensor_h2d from .interfaces import ( MultiModalEmbeddings, SupportsEagle, @@ -2156,7 +2158,9 @@ def _postprocess_image_embeds_evs( grid_thw_list = grid_thw.tolist() image_embeds_out = [] for emb, size in zip(image_embeds_split, grid_thw_list): - positions = compute_mrope_for_media(size, merge_size).to(emb.device) + positions = compute_mrope_for_media(size, merge_size).to( + emb.device, non_blocking=True + ) positions = torch.cat( [ positions, @@ -2211,19 +2215,25 @@ def _postprocess_video_embeds_evs( spatial_merge_size=self.visual.spatial_merge_size, q=self.video_pruning_rate, ) - # Apply retention mask. - emb = emb[retention_mask] - - # Calculate the actual number of retained tokens per frame. - num_frames, rows, cols = ( - t, - h // merge_size, - w // merge_size, - ) - retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) - num_tokens_per_frame = ( - retention_mask_thw.sum(dim=(1, 2)).long().tolist() - ) + # Boolean-mask indexing has a data-dependent output shape and + # always syncs on CUDA. The `.tolist()` below is also a sync + # but is required because `num_tokens_per_frame` is consumed + # downstream as Python ints. Runs once per video in the EVS + # path. + with gpu_sync_allowed(): + # Apply retention mask. + emb = emb[retention_mask] + + # Calculate the actual number of retained tokens per frame. + num_frames, rows, cols = ( + t, + h // merge_size, + w // merge_size, + ) + retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) + num_tokens_per_frame = ( + retention_mask_thw.sum(dim=(1, 2)).long().tolist() + ) else: feature_size = emb.shape[0] // num_frames num_tokens_per_frame = [feature_size] * num_frames @@ -2387,14 +2397,19 @@ def _get_expanded_positions( input_tokens=unpruned_token_ids, mm_features=[mm_feature], )[0] - .to(device) + .to(device, non_blocking=True) .permute(1, 0) ) full_is_video_embed = unpruned_token_ids_tensor == embed_token_id - expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][ - retention_mask - ] - expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed] + # Boolean-mask indexing has data-dependent output shapes and always + # syncs on CUDA; runs once per video in the EVS path. 
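+        # (The result length equals mask.sum(), which the host must read
+        # back before it can allocate the output, hence the unavoidable D2H.)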
+ with gpu_sync_allowed(): + expanded_positions[is_video_embed, :3] = original_mrope[ + full_is_video_embed + ][retention_mask] + expanded_positions[~is_video_embed, :3] = original_mrope[ + ~full_is_video_embed + ] expanded_positions[..., 3] = is_vision_start expanded_positions[..., 4] = is_video_embed @@ -2635,7 +2650,7 @@ def _recompute_mrope_positions( ) # Tensors - input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + input_ids_t = async_tensor_h2d(input_ids, device=device, dtype=torch.long) mm_embeddings_out = [] mm_embeddings_pos = [] @@ -2653,15 +2668,16 @@ def _recompute_mrope_positions( torch.empty(5, 0, device=device, dtype=torch.long) ) - positions, mrope_positions_delta = recompute_mrope_positions( - input_ids_t, - mm_embeddings_pos, - mrope_positions, - num_computed_tokens, - vision_start_token_id, - image_token_id, - video_token_id, - ) + with gpu_sync_allowed(): + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) return tuple(mm_embeddings_out), positions, mrope_positions_delta diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index ce3a260d0ef6..f5a24881bb47 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -1112,31 +1112,37 @@ def _flip_sequences_by_position_ids( if len(features) == 1: return features - # Detect sequence boundaries where position_ids decrease + # Detect sequence boundaries where position_ids decrease. + # `torch.where(boundary_mask)` and `repeat_interleave(lengths)` both + # have data-dependent output shapes that always sync on CUDA. This + # runs once per pooling call. Compute boundaries on CPU (single D2H + # for `boundary_mask`) then upload the final flip-index tensor. 
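+        # Illustrative example (values not from the code): position_ids
+        # [0, 1, 2, 0, 1] yields a boundary at index 3, sequences [0, 3) and
+        # [3, 5), and flip_indices [2, 1, 0, 4, 3], i.e. each packed
+        # sequence is reversed in place.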
+ from vllm.utils.gpu_sync_debug import gpu_sync_allowed + position_diffs = position_ids[1:] - position_ids[:-1] boundary_mask = position_diffs <= 0 - boundary_indices = torch.cat( - [ - torch.tensor([0], device=features.device), - torch.where(boundary_mask)[0] + 1, - torch.tensor([len(features)], device=features.device), - ] + with gpu_sync_allowed(): + boundary_mid_cpu = torch.where(boundary_mask.cpu())[0] + 1 + zero = torch.zeros(1, dtype=boundary_mid_cpu.dtype) + end = torch.full((1,), len(features), dtype=boundary_mid_cpu.dtype) + boundary_indices_cpu = torch.cat([zero, boundary_mid_cpu, end]) + lengths_cpu = boundary_indices_cpu[1:] - boundary_indices_cpu[:-1] + starts_cpu = boundary_indices_cpu[:-1] + ends_cpu = boundary_indices_cpu[1:] + sequence_ids_cpu = torch.arange( + len(lengths_cpu), dtype=boundary_mid_cpu.dtype + ).repeat_interleave(lengths_cpu) + current_positions_cpu = torch.arange( + len(features), dtype=boundary_mid_cpu.dtype ) - - # For each sequence [start, end), position i flips to: start + end - 1 - i - lengths = boundary_indices[1:] - boundary_indices[:-1] - starts = boundary_indices[:-1] - ends = boundary_indices[1:] - - # Assign sequence ID to each element - sequence_ids = torch.arange( - len(lengths), device=features.device - ).repeat_interleave(lengths) - - # Calculate flipped indices for all positions at once - current_positions = torch.arange(len(features), device=features.device) - flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions + flip_indices_cpu = ( + starts_cpu[sequence_ids_cpu] + + ends_cpu[sequence_ids_cpu] + - 1 + - current_positions_cpu + ) + flip_indices = flip_indices_cpu.to(features.device, non_blocking=True) return features[flip_indices] diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index e863b0bb51c7..88184f5423f6 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -61,6 +61,7 @@ TimingContext, ) from vllm.sequence import IntermediateTensors +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal from .interfaces_base import attn_type @@ -277,7 +278,12 @@ def forward( inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): - model_output = self.inference_runner.forward(**kwargs) + # terratorch's forward has internal GPU syncs we can't fix from + # here (e.g. prithvi_mae._get_1d_sincos_embed_from_grid_torch + # does `torch.arange(...).to(pos.device)` every call). Suppress + # sync checking at this integration boundary. 
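+        # (`gpu_sync_allowed()`, added in vllm/utils/gpu_sync_debug.py in
+        # this diff, temporarily resets torch.cuda's sync debug mode to 0
+        # and restores the previous mode on exit.)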
+ with gpu_sync_allowed(): + model_output = self.inference_runner.forward(**kwargs) return model_output.output def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index ab6ba91d2438..857595df2916 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -152,7 +152,12 @@ def _get_mm_fields_config( # Keep these as batched, as they always have batch size as first dim mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") - mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") + # Scalar per-image patch counts — consumed only via `.tolist()` for + # Python-level split sizing; keep on CPU to avoid a needless H2D and + # the downstream D2H sync. + mm_fields["num_image_patches"] = MultiModalFieldConfig.batched( + "image", keep_on_cpu=True + ) return mm_fields def _get_hf_mm_data( @@ -387,6 +392,13 @@ def embed_multimodal(self, **kwargs): kwargs.pop("mm_token_type_ids", None) # used only in `model.get_rope_index` if pixel_values is not None: + # The underlying HuggingFace `get_image_features` implementations + # contain model-internal syncs (e.g. Idefics3 filters all-zero + # padding images via boolean-mask indexing, LlavaOnevision + # branches on per-sample batch counts). These are third-party and + # not something we can refactor here. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + # ROCm: Force math SDP backend for vision encoder to avoid accuracy issues # with flash_sdp and mem_efficient_sdp if current_platform.is_rocm(): @@ -398,16 +410,20 @@ def embed_multimodal(self, **kwargs): "`mem_efficient_sdp` backends. 
See issue: " "https://github.com/vllm-project/vllm/issues/30167" ) - with torch.nn.attention.sdpa_kernel( - backends=[torch.nn.attention.SDPBackend.MATH] + with ( + torch.nn.attention.sdpa_kernel( + backends=[torch.nn.attention.SDPBackend.MATH] + ), + gpu_sync_allowed(), ): vision_embeddings = self.model.get_image_features( pixel_values, **kwargs ) else: - vision_embeddings = self.model.get_image_features( - pixel_values, **kwargs - ) + with gpu_sync_allowed(): + vision_embeddings = self.model.get_image_features( + pixel_values, **kwargs + ) # Transformers `v5`, `self.get_image_features` returns a tuple # containing the features and optionally attentions/hidden_states diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 83241b329da3..f3f8b48a04ff 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -45,6 +45,7 @@ from vllm.renderers import TokenizeParams from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( @@ -692,15 +693,23 @@ def _process_audio_input( embeddings.shape[0], -1 ) mask = indices < audio_token_len[:, None] - # Apply mask and flatten - flattened_embeddings = embeddings[mask] - - # Return one tensor per input audio - embed_lens = [ - chunk_lens.sum().item() - for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist()) - ] - return flattened_embeddings.split(embed_lens) + # The boolean-mask gather, the `.item()` on per-group lengths, and the + # `.tolist()` feeding `torch.split` all force GPU->CPU syncs (output + # sizes are data-dependent). Suppress the sync check here; avoiding + # these syncs would require restructuring the downstream contract + # for per-audio flattened embeddings. 
+ with gpu_sync_allowed(): + # Apply mask and flatten + flattened_embeddings = embeddings[mask] + + # Return one tensor per input audio + embed_lens = [ + chunk_lens.sum().item() + for chunk_lens in audio_token_len.split( + audio_input["num_chunks"].tolist() + ) + ] + return flattened_embeddings.split(embed_lens) def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index a6c46339303e..19c2b4ee09f2 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -30,10 +30,8 @@ from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv -from vllm.utils.platform_utils import ( - is_pin_memory_available, -) from vllm.utils.torch_utils import ( + async_tensor_h2d, direct_register_custom_op, ) @@ -498,10 +496,9 @@ def isin_list( elements: torch.Tensor, test_elements_list: list[int], ) -> torch.Tensor: - test_elements = torch.tensor( - test_elements_list, - pin_memory=is_pin_memory_available(), - ).to(device=elements.device, non_blocking=True) + test_elements = async_tensor_h2d( + test_elements_list, dtype=torch.int64, device=elements.device + ) return torch.isin(elements, test_elements) @@ -772,14 +769,13 @@ def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: return layer_index -def cast_overflow_tensors( - tensors: torch.Tensor, - offset: float = 1000, -) -> torch.Tensor: - if tensors.isinf().any() or tensors.isnan().any(): - clamp_value = torch.finfo(tensors.dtype).max - offset - tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) - return tensors +def cast_overflow_tensors(tensors: torch.Tensor, offset: float = 1000) -> torch.Tensor: + # Always clamp rather than guarding with `.isinf().any()`/`.isnan().any()`, + # which would force a GPU->CPU sync on the hot path. Clamp is a no-op for + # in-range values and turns +/-inf into the finite min/max; NaN passes + # through either way, matching the previous guard's behavior. + clamp_value = torch.finfo(tensors.dtype).max - offset + return torch.clamp(tensors, min=-clamp_value, max=clamp_value) def fast_topk( diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py index 62611c89719a..0f1c5bec71ae 100644 --- a/vllm/multimodal/evs.py +++ b/vllm/multimodal/evs.py @@ -85,7 +85,9 @@ def compute_retention_mask( topk_indices = order[:retain_num_tokens] retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool) - retention_mask[topk_indices] = True + # Use `index_fill_` instead of advanced-index scalar assign to avoid the + # synchronizing `aten::index_put_` path on CUDA. + retention_mask.index_fill_(0, topk_indices, True) retention_mask = retention_mask.reshape(dissimilarity.size()) mask = retention_mask.view(-1) # "T H W -> (T H W)" diff --git a/vllm/utils/gpu_sync_debug.py b/vllm/utils/gpu_sync_debug.py new file mode 100644 index 000000000000..aebe91f65b06 --- /dev/null +++ b/vllm/utils/gpu_sync_debug.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import sys +from contextlib import contextmanager + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform + +_GPU_SYNC_ALLOWED_FIRST_SEEN: set[tuple[str, int]] = set() + +# Global sync-check gate. 
Off during engine setup (model load, KV cache +# init, warmup/compile) so first-compile and lazy-init syncs pass through; +# flipped on by `enable_gpu_sync_check()` at the end of +# `GPUWorker.compile_or_warm_up_model`, after which `with_gpu_sync_check`- +# decorated functions activate the configured debug mode. +_sync_check_enabled: bool = False + + +def enable_gpu_sync_check() -> None: + """Flip the sync-check gate on. Call once per worker, after warmup / + first-compile is complete.""" + global _sync_check_enabled + _sync_check_enabled = True + _install_compile_time_sync_suppressors() + + +_compile_time_suppressors_installed: bool = False + + +def _install_compile_time_sync_suppressors() -> None: + """Wrap torch inductor/aot_autograd compile entry points so the + synchronizing ops those passes perform (e.g. `constant_fold_uniform_value` + calling `.item()` on uniform-valued constants) don't trip the + sync-check mode we set around `execute_model` / `sample_tokens`. + + Warmup-time compiles already run under the gate (before + `enable_gpu_sync_check`), but post-warmup compiles (runtime + recompiles from dynamic shape variants, pipeline-parallel fresh + compile cache, etc.) fire inside `execute_model`. We intentionally + only want to flag *model-execution* syncs — compile-time work is + third-party and unavoidable. + """ + global _compile_time_suppressors_installed + if _compile_time_suppressors_installed: + return + _compile_time_suppressors_installed = True + + try: # noqa: BLE001 + from torch._inductor.fx_passes import joint_graph as _jg + + _orig_joint = _jg.joint_graph_passes + + @functools.wraps(_orig_joint) + def _wrapped_joint(*args, **kwargs): + prev_mode = torch.cuda.get_sync_debug_mode() + if not prev_mode: + return _orig_joint(*args, **kwargs) + torch.cuda.set_sync_debug_mode(0) + try: + return _orig_joint(*args, **kwargs) + finally: + torch.cuda.set_sync_debug_mode(prev_mode) + + # `compile_fx` does `from .fx_passes.joint_graph import + # joint_graph_passes`, which binds the *function object* at import + # time. Patching just the module attribute won't update that rebind, + # so patch every already-imported reference we can find. Restrict + # the scan to torch's compile-time modules — iterating all of + # `sys.modules` triggers `__getattr__` shims on third-party packages + # (e.g. transformers image_processing modules emit a deprecation + # warning on every attribute access). + import sys as _sys + + setattr(_jg, "joint_graph_passes", _wrapped_joint) # noqa: B010 + for _name, _mod in list(_sys.modules.items()): + if _mod is None: + continue + if not ( + _name.startswith("torch._inductor") + or _name.startswith("torch._functorch") + or _name.startswith("torch._dynamo") + ): + continue + if getattr(_mod, "joint_graph_passes", None) is _orig_joint: + setattr(_mod, "joint_graph_passes", _wrapped_joint) # noqa: B010 + except Exception: # pragma: no cover + pass + + +@contextmanager +def _suppress_gpu_sync_check(prev_mode: int): + torch.cuda.set_sync_debug_mode(0) + try: + yield + finally: + torch.cuda.set_sync_debug_mode(prev_mode) + + +@contextmanager +def _noop_cm(): + yield + + +if current_platform.is_cuda_alike(): + + def gpu_sync_allowed(first_only: bool = False): + """Context manager that suppresses `torch.cuda.set_sync_debug_mode` for the + duration of the `with` block. + + If `first_only` is True, only the first entry from this call site + suppresses the sync check; subsequent entries from the same site are + no-ops so any further GPU syncs will be reported. 
The "site" is the + caller's (filename, lineno), so different + `with gpu_sync_allowed(first_only=True):` lines track independently. + """ + if torch.compiler.is_compiling(): + return _noop_cm() + prev_mode = torch.cuda.get_sync_debug_mode() + if not prev_mode: + return _noop_cm() + if first_only: + frame = sys._getframe(1) + key = (frame.f_code.co_filename, frame.f_lineno) + if key in _GPU_SYNC_ALLOWED_FIRST_SEEN: + return _noop_cm() + _GPU_SYNC_ALLOWED_FIRST_SEEN.add(key) + return _suppress_gpu_sync_check(prev_mode) + + def with_gpu_sync_check(fn): + """Decorator that enables `torch.cuda.set_sync_debug_mode` around `fn` + when `VLLM_GPU_SYNC_CHECK` is set *and* the gate has been flipped by + `enable_gpu_sync_check()`. Before the gate flips (i.e. during + engine setup / warmup) the decorated function runs as-is. + + The env var is parsed once at decoration time; this module is imported + lazily after `VllmConfig.__post_init__` has finalized `VLLM_GPU_SYNC_CHECK`. + """ + mode = envs.VLLM_GPU_SYNC_CHECK + if mode is None: + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not _sync_check_enabled: + return fn(*args, **kwargs) + prev_mode = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode(mode) + try: + return fn(*args, **kwargs) + finally: + torch.cuda.set_sync_debug_mode(prev_mode) + + return wrapper + +else: + + def gpu_sync_allowed(first_only: bool = False): + return _noop_cm() + + def with_gpu_sync_check(fn): + return fn diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 1eb9306ed4b1..8cbc01261fa6 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -3,6 +3,7 @@ import contextlib import importlib.metadata import os +import platform import random import threading from collections.abc import Callable, Collection @@ -67,6 +68,11 @@ T = TypeVar("T") +# Pin memory in non-WSL case. +# Logic duplicated here for now to avoid circular import. +PIN_MEMORY = "microsoft" not in " ".join(platform.uname()).lower() + + def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: return ( kv_cache_dtype.startswith("fp8") @@ -576,12 +582,12 @@ def create_kv_caches_with_random( def async_tensor_h2d( data: list, dtype: torch.dtype, - target_device: str | torch.device, - pin_memory: bool, + device: str | torch.device, + pin_memory: bool = PIN_MEMORY, ) -> torch.Tensor: """Asynchronously create a tensor and copy it from host to device.""" t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") - return t.to(device=target_device, non_blocking=True) + return t.to(device=device, non_blocking=True) def make_ndarray_with_pad( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 662ead1d1d01..34a362441a0e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1093,7 +1093,8 @@ def build( paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[ prefill_start : num_reqs + 1 ] - paged_kv_indptr_prefill_gpu[0] = 0 + # Assign to slice to avoid cpu sync. 
+ paged_kv_indptr_prefill_gpu[:1] = 0 torch.cumsum( num_blocks_per_req, dim=0, diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index b70902478e8f..c7475a026220 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -47,12 +47,16 @@ flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) -def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: - device = offsets.device - counts = offsets[1:] - offsets[:-1] - return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts +def _offsets_to_doc_ids_tensor( + offsets_cpu: torch.Tensor, device: torch.device +) -> torch.Tensor: + # Build on CPU (so `repeat_interleave` doesn't force a GPU->CPU sync to + # learn the data-dependent output length) and upload non-blocking. + counts = offsets_cpu[1:] - offsets_cpu[:-1] + doc_ids = torch.repeat_interleave( + torch.arange(len(counts), dtype=torch.int32), counts ) + return doc_ids.to(device, non_blocking=True) def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): @@ -285,11 +289,13 @@ def unique_static_unsorted( keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] # ── left-pack uniques into a fresh tensor ─────────────────────────── + # Route non-kept entries to a garbage slot at column N so we can do a + # single scatter rather than using torch.nonzero (which would force a + # GPU->CPU sync to enumerate kept positions). dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go - packed_flat = torch.full_like(x_flat, pad_val) - - rows, src_cols = torch.nonzero(keep, as_tuple=True) - packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + dest_pos = torch.where(keep, dest_pos, N) + packed_extended = torch.full((B, N + 1), pad_val, device=device, dtype=x_flat.dtype) + packed_flat = packed_extended.scatter_(1, dest_pos, x_flat)[:, :N] # ── restore original layout ───────────────────────────────────────── packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) @@ -341,6 +347,9 @@ class FlexAttentionMetadata: num_actual_tokens: int # Number of tokens excluding padding. max_query_len: int query_start_loc: torch.Tensor + # CPU-resident copy of query_start_loc used to derive doc_ids without a + # GPU->CPU sync from repeat_interleave's data-dependent output size. + query_start_loc_cpu: torch.Tensor max_seq_len: int seq_lens: torch.Tensor block_table: torch.Tensor @@ -447,12 +456,7 @@ def final_mask_mod( (is_valid, logical_q_idx, logical_kv_idx) = ( self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) ) - # Apply mask modification only for valid indices - return torch.where( - is_valid, - self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx), - False, - ) + return is_valid & self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx) return final_mask_mod @@ -464,7 +468,9 @@ def get_bidirectional_mask_mod(self) -> _mask_mod_signature: packed query sequences. 
""" # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + request_lookup = _offsets_to_doc_ids_tensor( + self.query_start_loc_cpu, self.query_start_loc.device + ) def final_mask_mod( b: torch.Tensor, @@ -576,7 +582,9 @@ def get_transformed_score_mod(self) -> _score_mod_signature | None: return None # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + request_lookup = _offsets_to_doc_ids_tensor( + self.query_start_loc_cpu, self.query_start_loc.device + ) user_score_mod = self.score_mod def transformed_score_mod( @@ -721,7 +729,9 @@ def __post_init__(self): assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." # Create a lookup mapping from query indices -> request number - self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) + self.doc_ids = _offsets_to_doc_ids_tensor( + self.query_start_loc_cpu, self.query_start_loc.device + ) self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids) self.num_blocks = self.total_cache_tokens // self.block_size @@ -802,6 +812,7 @@ def build( max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -866,6 +877,7 @@ def build( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=block_table_tensor, diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index fa7d4bd2ec51..6826541271d5 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -7,6 +7,7 @@ import torch from vllm.config import VllmConfig +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.v1.attention.backend import ( AttentionBackend, CommonAttentionMetadata, @@ -147,11 +148,14 @@ def build( # Compute seq_idx for prefill only if common.num_prefills > 0: - prep_initial_states = ( - torch.any(common.has_initial_states_p).item() - if common.has_initial_states_p is not None - else False - ) + prep_initial_states = False + if common.has_initial_states_p is not None: + # TODO: avoid this sync by either always running the torch.where + # in mamba_mixer2.py (dropping the prep_initial_states gate), or + # plumbing a CPU-side num_computed_tokens once the deprecated + # CommonAttentionMetadata._num_computed_tokens_cpu migration lands. 
+ with gpu_sync_allowed(): + prep_initial_states = torch.any(common.has_initial_states_p).item() cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = ( self._build_chunk_metadata_tensors( diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index eec53032288d..716dfcde592f 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -9,6 +9,7 @@ from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -270,16 +271,20 @@ def _build_chunk_metadata_tensors( num_prefills = common.num_prefills num_decode_tokens = common.num_decode_tokens - num_computed_tokens_cpu = ( - common_attn_metadata.compute_num_computed_tokens().cpu() - ) - num_computed_tokens_p_cpu = num_computed_tokens_cpu[ - num_reqs - num_prefills : num_reqs - ] + # Derive prefill context lengths from CPU data only. + # `seq_lens_cpu_upper_bound` is precise for prefill rows in all modes + # (including async spec decode), so this avoids the D2H sync that + # `compute_num_computed_tokens().cpu()` would force. + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + assert seq_lens_cpu is not None query_start_loc_p_cpu = ( common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens ) + prefill_query_lens_cpu = query_start_loc_p_cpu[1:] - query_start_loc_p_cpu[:-1] + num_computed_tokens_p_cpu = ( + seq_lens_cpu[num_reqs - num_prefills : num_reqs] - prefill_query_lens_cpu + ) cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata( chunk_size, @@ -289,20 +294,14 @@ def _build_chunk_metadata_tensors( ) device = common_attn_metadata.query_start_loc.device - cu_chunk_seqlen_p = torch.as_tensor( - cu_chunk_seqlen, - device=device, - dtype=torch.int32, - ) - seq_idx_p = torch.as_tensor( - seq_idx, - device=device, - dtype=torch.int32, + # Build on pinned CPU and upload non-blocking to avoid the synchronous + # H2D copy that `torch.as_tensor(list, device=cuda)` would force. + cu_chunk_seqlen_p = async_tensor_h2d( + cu_chunk_seqlen, dtype=torch.int32, device=device ) - last_chunk_indices_p = torch.as_tensor( - last_chunk_indices, - device=device, - dtype=torch.int32, + seq_idx_p = async_tensor_h2d(seq_idx, dtype=torch.int32, device=device) + last_chunk_indices_p = async_tensor_h2d( + last_chunk_indices, dtype=torch.int32, device=device ) return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index ceee8d5499ea..af9c91d11ee2 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -90,6 +90,14 @@ class TreeAttentionMetadata: num_prefills: int = 0 num_decodes: int = 0 + # Precomputed (on CPU in the builder) max_query_len and max_seq_len for + # the prefill-only and decode-only sub-batches. Used by the properties + # below to avoid a GPU->CPU sync via `.max().item()` on every forward. + max_query_len_prefill: int = 0 + max_seq_len_prefill: int = 0 + max_query_len_decode: int = 0 + max_seq_len_decode: int = 0 + tree_attn_bias: torch.Tensor | None = None # Cached Prefill/decode metadata. 
@@ -107,14 +115,13 @@ def prefill_metadata(self) -> "TreeAttentionMetadata | None": return self._cached_prefill_metadata q_start_loc = self.query_start_loc[self.num_decodes :] - q_seqlens = torch.diff(q_start_loc) kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, - max_query_len=int(q_seqlens.max().item()), + max_query_len=self.max_query_len_prefill, query_start_loc=q_start_loc - q_start_loc[0], - max_seq_len=int(kv_seqlens.max().item()), + max_seq_len=self.max_seq_len_prefill, seq_lens=kv_seqlens, block_table=self.block_table[self.num_decodes :], slot_mapping=self.slot_mapping[self.num_decode_tokens :], @@ -132,14 +139,13 @@ def decode_metadata(self) -> "TreeAttentionMetadata | None": return self._cached_decode_metadata q_start_loc = self.query_start_loc[: self.num_decodes + 1] - q_seqlens = torch.diff(q_start_loc) kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_decode_tokens, - max_query_len=int(q_seqlens.max().item()), + max_query_len=self.max_query_len_decode, query_start_loc=q_start_loc, - max_seq_len=int(kv_seqlens.max().item()), + max_seq_len=self.max_seq_len_decode, seq_lens=kv_seqlens, block_table=self.block_table[: self.num_decodes], slot_mapping=self.slot_mapping[: self.num_decode_tokens], @@ -199,6 +205,42 @@ def build( block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + # Precompute prefill/decode sub-batch max_query_len / max_seq_len on + # CPU so the prefill_metadata / decode_metadata properties don't need + # a GPU->CPU sync via `.max().item()` on every forward. + # Prefer `seq_lens_cpu_upper_bound` over the (deprecated) + # `seq_lens_cpu` property: the upper bound is precise for prefill + # rows and optimistic-but-safe for decode rows (workspace sizing + # from `max()` is fine with an over-estimate), and avoids the + # `seq_lens.to("cpu")` sync the property would fall through to in + # async-spec-decode mode. The draft-attention path (eagle + # speculator) doesn't populate it; fall back to the batch-wide + # `max_seq_len` as a safe upper bound for both sub-batches. 
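+        # Illustrative example (values not from the code): with
+        # num_decodes=2, num_prefills=1 and query_start_loc_cpu=[0, 1, 2, 9],
+        # the decode slice [0, 1, 2] gives max_query_len_decode=1 and the
+        # prefill slice [2, 9] gives max_query_len_prefill=7.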
+ q_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + if num_prefills > 0: + q_seqlens_p = torch.diff(q_start_loc_cpu[num_decodes:]) + max_query_len_prefill = int(q_seqlens_p.max()) + max_seq_len_prefill = ( + int(seq_lens_cpu[num_decodes:].max()) + if seq_lens_cpu is not None + else max_seq_len + ) + else: + max_query_len_prefill = 0 + max_seq_len_prefill = 0 + if num_decodes > 0: + q_seqlens_d = torch.diff(q_start_loc_cpu[: num_decodes + 1]) + max_query_len_decode = int(q_seqlens_d.max()) + max_seq_len_decode = ( + int(seq_lens_cpu[:num_decodes].max()) + if seq_lens_cpu is not None + else max_seq_len + ) + else: + max_query_len_decode = 0 + max_seq_len_decode = 0 + return TreeAttentionMetadata( num_actual_tokens=num_actual_tokens, num_prefill_tokens=num_prefill_tokens, @@ -211,6 +253,10 @@ def build( seq_lens=kv_seqlens, block_table=block_table, slot_mapping=slot_mapping, + max_query_len_prefill=max_query_len_prefill, + max_seq_len_prefill=max_seq_len_prefill, + max_query_len_decode=max_query_len_decode, + max_seq_len_decode=max_seq_len_decode, tree_attn_bias=self.tree_attn_bias, ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index f254d95a414c..0aeb33ae6fd9 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,7 +18,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import next_power_of_2 -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import async_tensor_h2d, is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -117,10 +117,9 @@ def compute_mm_prefix_range_tensor( for r in range_lists: padded_r = list(r) + [(0, 0)] * (max_ranges - len(r)) padded.append(padded_r) - # Create tensor with efficient H2D transfer - return torch.tensor(padded, dtype=torch.int32, device=device).view( - num_seqs, max_ranges, 2 - ) + # Build on pinned CPU memory so the H2D transfer is non-blocking. + padded = async_tensor_h2d(padded, dtype=torch.int32, device=device) + return padded.view(num_seqs, max_ranges, 2) class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py index af2d0fb0830f..53684b4360f7 100644 --- a/vllm/v1/attention/backends/turboquant_attn.py +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -187,6 +187,10 @@ class TurboQuantMetadata(AttentionMetadata): is_prefill: bool = False num_decodes: int = 0 # number of decode requests (first in batch) num_decode_tokens: int = 0 # tokens from decode requests + # CPU-resident copies used by the prefill path for per-request iteration + # without per-step D2H syncs. + query_start_loc_cpu: torch.Tensor | None = None + seq_lens_cpu: torch.Tensor | None = None class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]): @@ -230,6 +234,8 @@ def build(self, common_prefix_len, common_attn_metadata, fast_build=False): is_prefill=(cam.max_query_len > 1), num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, + query_start_loc_cpu=cam.query_start_loc_cpu, + seq_lens_cpu=cam.seq_lens_cpu_upper_bound, ) @@ -474,11 +480,21 @@ def forward( # first-chunk prefills. Using full-batch max_seq_len breaks # this because decode requests inflate max_seq_len. 
prefill_seq_lens = attn_metadata.seq_lens[num_decodes:] - # Use CPU-side max to avoid GPU→CPU sync from .item() - prefill_max_seq = max(attn_metadata.seq_lens[num_decodes:].tolist()) + # Use the CPU-resident `seq_lens` upper-bound from the metadata + # (populated in the builder) to compute the prefill sub-batch + # max without a GPU→CPU sync. + if attn_metadata.seq_lens_cpu is not None: + prefill_max_seq = int(attn_metadata.seq_lens_cpu[num_decodes:].max()) + else: + prefill_max_seq = attn_metadata.max_seq_len prefill_qsl = ( attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens ) + prefill_qsl_cpu = None + if attn_metadata.query_start_loc_cpu is not None: + prefill_qsl_cpu = ( + attn_metadata.query_start_loc_cpu[num_decodes:] - num_decode_tokens + ) prefill_meta = TurboQuantMetadata( seq_lens=prefill_seq_lens, slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N], @@ -488,6 +504,10 @@ def forward( max_query_len=attn_metadata.max_query_len, max_seq_len=prefill_max_seq, is_prefill=True, + query_start_loc_cpu=prefill_qsl_cpu, + seq_lens_cpu=attn_metadata.seq_lens_cpu[num_decodes:] + if attn_metadata.seq_lens_cpu is not None + else None, ) k = key[:N].view(N, self.num_kv_heads, self.head_size) v = value[:N].view(N, self.num_kv_heads, self.head_size) @@ -578,10 +598,16 @@ def _prefill_attention( output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype) - # Convert to Python lists once (single CPU-GPU sync) instead of - # per-request .item() calls that each force a sync. - qsl = query_start_loc.tolist() - seq_lens_list = attn_metadata.seq_lens.tolist() + # Prefer the CPU-resident copies from the metadata if populated — + # otherwise `.tolist()` on GPU tensors forces a synchronizing copy. + if attn_metadata.query_start_loc_cpu is not None: + qsl = attn_metadata.query_start_loc_cpu.tolist() + else: + qsl = query_start_loc.tolist() + if attn_metadata.seq_lens_cpu is not None: + seq_lens_list = attn_metadata.seq_lens_cpu.tolist() + else: + seq_lens_list = attn_metadata.seq_lens.tolist() # Pre-allocate cu_seqlens for single-request flash_attn calls # to avoid per-request host→device tensor creation. @@ -612,7 +638,8 @@ def _prefill_attention( if q_len == seq_len: # First-chunk prefill: all K/V are in the current batch. if _HAS_FLASH_ATTN: - self._cu_2[1] = q_len + # Assign to slice to avoid gpu/cpu sync. + self._cu_2[1:2] = q_len cu = self._cu_2 out = self._flash_attn_varlen( q=q_seq, @@ -791,8 +818,9 @@ def _continuation_prefill( if not hasattr(self, "_cu_2_q"): self._cu_2_q = torch.zeros(2, device=device, dtype=torch.int32) self._cu_2_k = torch.zeros(2, device=device, dtype=torch.int32) - self._cu_2_q[1] = q_len - self._cu_2_k[1] = seq_len + # Assigning to slice uses fill_ which avoids cpu/gpu sync. + self._cu_2_q[1:2] = q_len + self._cu_2_k[1:2] = seq_len cu_seqlens_q = self._cu_2_q cu_seqlens_k = self._cu_2_k return self._flash_attn_varlen( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 54ebd088b95e..1304e41da636 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -332,8 +332,10 @@ def make_local_attention_virtual_batches( # regression when using numpy arrays (batch and block indices) to index into # torch tensor (block_table). As a workaround, convert numpy arrays to torch # tensor first, which recovers perf. 
- batch_indices_torch = torch.from_numpy(batch_indices) - block_indices_torch = torch.from_numpy(block_indices) + # Upload the index tensors to the block_table's device up-front so that the + # fancy indexing below doesn't implicitly force a synchronous H2D copy. + batch_indices_torch = torch.from_numpy(batch_indices).to(device, non_blocking=True) + block_indices_torch = torch.from_numpy(block_indices).to(device, non_blocking=True) # Save as a lambda so we can return this for update_block_table make_block_table = lambda block_table: block_table[ @@ -391,7 +393,16 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Figure out how many tokens are in each request # num_decode_tokens: [1, 2, 1] - num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + # Avoid `torch.bincount` here — on CUDA it forces a sync to determine + # the output size (even with `minlength`, the kernel must confirm no + # value exceeds the bound). `scatter_add_` into a preallocated buffer + # is equivalent and stays async. + num_decode_tokens = torch.zeros( + num_reqs, dtype=request_ids.dtype, device=request_ids.device + ) + num_decode_tokens.scatter_add_( + 0, request_ids.to(num_decode_tokens.dtype), torch.ones_like(request_ids) + ) # Calculate new query_start_loc with tokens in generation_indices # decode_query_start_loc: [0, 1, 3, 4] @@ -399,10 +410,15 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype ) - decode_query_start_loc[0] = 0 + decode_query_start_loc[:1].fill_(0) # Avoid sync from scalar assignment. decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) - decode_max_query_len = int(num_decode_tokens.max().item()) - total_num_decode_tokens = int(num_decode_tokens.sum().item()) + # `.item()` reductions here are unavoidable — the CommonAttentionMetadata + # fields below need Python ints. Feature is opt-in (kv_sharing_fast_prefill). + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) common_attn_metadata = CommonAttentionMetadata( query_start_loc=decode_query_start_loc, diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index 4506f452cf9a..165a3fcc94e4 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -17,6 +17,7 @@ import torch.nn.functional as F from vllm.platforms import current_platform +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.torch_utils import direct_register_custom_op @@ -45,7 +46,13 @@ def flash_attn_maxseqlen_wrapper( cu_seqlens = torch.arange( 0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device ) - max_seqlen = q_len if max_seqlen is None else max_seqlen.item() + if max_seqlen is None: + max_seqlen = q_len + else: + # `flash_attn_varlen_func` needs a Python int for kernel launch + # bounds; the D2H is unavoidable here. + with gpu_sync_allowed(): + max_seqlen = max_seqlen.item() q, k, v = (einops.rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func( @@ -126,7 +133,12 @@ def triton_attn_wrapper( cu_seqlens = torch.arange( 0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device ) - max_seqlen = q_len if max_seqlen is None else max_seqlen.item() + if max_seqlen is None: + max_seqlen = q_len + else: + # `context_attention_fwd` needs a Python int; the D2H is unavoidable. + with gpu_sync_allowed(): + max_seqlen = max_seqlen.item() q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = torch.empty_like(q) @@ -228,7 +240,9 @@ def torch_sdpa_wrapper( outputs = [] - lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + # `torch.split` needs Python int sizes; the D2H is unavoidable here. + with gpu_sync_allowed(): + lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() q_chunks = torch.split(q, lens, dim=1) k_chunks = torch.split(k, lens, dim=1) v_chunks = torch.split(v, lens, dim=1) @@ -304,7 +318,10 @@ def flashinfer_wrapper( batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1) batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1) sequence_lengths = sequence_lengths.view(-1, 1, 1, 1) - max_seqlen = max_seqlen.item() + # `cudnn_batch_prefill_with_kv_cache` needs Python ints for the + # max-token-per-seq bounds; the D2H is unavoidable here. + with gpu_sync_allowed(): + max_seqlen = max_seqlen.item() output, _ = cudnn_batch_prefill_with_kv_cache( q, diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py index 56972e517980..e253abce2e9f 100644 --- a/vllm/v1/sample/ops/bad_words.py +++ b/vllm/v1/sample/ops/bad_words.py @@ -23,7 +23,8 @@ def _apply_bad_words_single_batch( assert len(actual_prefix) == len(expected_prefix) if actual_prefix == expected_prefix: - logits[last_token_id] = _SMALLEST_LOGIT + # assign to slice to avoid cpu->gpu sync + logits[last_token_id : last_token_id + 1] = _SMALLEST_LOGIT def apply_bad_words( diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 4c7c3e99d44b..0f8749b3cf5e 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -12,6 +12,7 @@ import torch from vllm.triton_utils import tl, triton +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.math_utils import next_power_of_2 from vllm.utils.platform_utils import num_compute_units @@ -1023,12 +1024,13 @@ def apply_top_k_top_p_triton( # Cache lookup table entries on each device. 
tables = _TRITON_TABLE_CACHE.get(logits.device) if tables is None: - normal_cdf_to_sigma_table = logits.new_tensor(_NORMAL_CDF_TO_SIGMA_TABLE) - percentile_to_std_table = logits.new_tensor(_PERCENTILE_TO_STD_TABLE) - _TRITON_TABLE_CACHE[logits.device] = ( - normal_cdf_to_sigma_table, - percentile_to_std_table, - ) + with gpu_sync_allowed(): + normal_cdf_to_sigma_table = logits.new_tensor(_NORMAL_CDF_TO_SIGMA_TABLE) + percentile_to_std_table = logits.new_tensor(_PERCENTILE_TO_STD_TABLE) + _TRITON_TABLE_CACHE[logits.device] = ( + normal_cdf_to_sigma_table, + percentile_to_std_table, + ) else: normal_cdf_to_sigma_table, percentile_to_std_table = tables diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index a77eafba2556..73f0047a2783 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -6,6 +6,7 @@ import torch.nn as nn from vllm.config.model import LogprobsMode +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -176,26 +177,28 @@ def gather_specific_token_logprobs( # Find max number of tokens across all requests max_num_tokens = max(len(tids) for tids in logprob_token_ids.values()) - # Create padded token_ids tensor: [batch_size, max_num_tokens + 1] - # +1 for sampled token in first position - token_ids_tensor = torch.zeros( - batch_size, max_num_tokens + 1, dtype=torch.int64, device=device + # Build the padded token_ids and valid_mask matrices on pinned CPU, + # then upload non-blocking. + token_ids_cpu = torch.zeros( + batch_size, max_num_tokens + 1, dtype=torch.int64, pin_memory=True ) - token_ids_tensor[:, 0] = sampled # First column is sampled token - # Create mask for valid positions (True = valid, False = padded) - valid_mask = torch.zeros( - batch_size, max_num_tokens + 1, dtype=torch.bool, device=device + valid_mask_cpu = torch.zeros( + batch_size, max_num_tokens + 1, dtype=torch.bool, pin_memory=True ) - valid_mask[:, 0] = True # Sampled token is always valid - - # Fill in token IDs for each request + valid_mask_cpu[:, 0] = True # Sampled token is always valid for req_idx, token_ids in logprob_token_ids.items(): num_tokens = len(token_ids) - token_ids_tensor[req_idx, 1 : num_tokens + 1] = torch.tensor( - token_ids, dtype=torch.int64, device=device + token_ids_cpu[req_idx, 1 : num_tokens + 1] = torch.as_tensor( + token_ids, dtype=torch.int64 ) - valid_mask[req_idx, 1 : num_tokens + 1] = True + valid_mask_cpu[req_idx, 1 : num_tokens + 1] = True + + token_ids_tensor = token_ids_cpu.to(device, non_blocking=True) + valid_mask = valid_mask_cpu.to(device, non_blocking=True) + # Sampled token in column 0 — fill on-device from the sampled GPU + # tensor so we don't need to D2H + re-upload. + token_ids_tensor[:, 0] = sampled # Compute logprobs using the fused Triton kernel (log_softmax + gather) logprobs = compute_token_logprobs(logits, token_ids_tensor) @@ -328,9 +331,10 @@ def gather_logprobs( # of the compiled batched_count_greater_than. mark_unbacked makes # the size fully symbolic so dynamo doesn't specialize when # batch_size transitions from 1 to >=2. 
- torch._dynamo.decorators.mark_unbacked(logprobs, 0) - torch._dynamo.decorators.mark_unbacked(token_logprobs, 0) - token_ranks = batched_count_greater_than(logprobs, token_logprobs) + with gpu_sync_allowed(first_only=True): + torch._dynamo.decorators.mark_unbacked(logprobs, 0) + torch._dynamo.decorators.mark_unbacked(token_logprobs, 0) + token_ranks = batched_count_greater_than(logprobs, token_logprobs) # Concatenate together with the topk. indices = torch.cat((token_ids, topk_indices), dim=1) diff --git a/vllm/v1/sample/thinking_budget_state.py b/vllm/v1/sample/thinking_budget_state.py index 74599a1e8c55..303c9f586801 100644 --- a/vllm/v1/sample/thinking_budget_state.py +++ b/vllm/v1/sample/thinking_budget_state.py @@ -6,6 +6,7 @@ import torch +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.sample.logits_processor.interface import ( BatchUpdate, MoveDirectionality, @@ -65,22 +66,9 @@ def __init__( self.cu_num_tokens: dict[int, int] = {} if self.num_spec_tokens > 0: - self.mask = torch.zeros( - max_num_reqs * (self.num_spec_tokens + 1), - dtype=torch.bool, - device=device, - ) - self.force_token_ids = torch.full( - (max_num_reqs * (self.num_spec_tokens + 1),), - -1, - dtype=torch.long, - device=device, - ) + self._mask_capacity = max_num_reqs * (self.num_spec_tokens + 1) else: - self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device) - self.force_token_ids = torch.full( - (max_num_reqs,), -1, dtype=torch.long, device=device - ) + self._mask_capacity = max_num_reqs def has_tracked_requests(self) -> bool: """True when ``sync_batch`` has state for a ``thinking_token_budget`` row. @@ -454,7 +442,6 @@ def _apply_forcing_to_logits( predict_bonus_token: bool, spec_token_ids_for_layout: list[list[int]], ) -> torch.Tensor: - self.mask[:] = False cumulative_total = 0 self.cu_num_tokens.clear() @@ -474,6 +461,12 @@ def _apply_forcing_to_logits( else: cumulative_total += 1 + # Build the active index / forced-token lists entirely on CPU so we + # avoid per-iteration scalar writes to GPU tensors (each of which + # forces a synchronizing H2D copy). 
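+        # The collected indices/tokens are uploaded once via async_tensor_h2d
+        # and applied with a single index_put_ at the end of this method.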
+ active_indices_cpu: list[int] = [] + force_tokens_cpu: list[int] = [] + for seq_idx in sorted(self._state.keys()): if seq_idx not in self.cu_num_tokens: continue @@ -502,9 +495,12 @@ def _apply_forcing_to_logits( for force_idx in force_index: if end_count < len(self.think_end_token_ids): mask_idx = self.cu_num_tokens[seq_idx] + force_idx - if mask_idx < len(self.mask) and mask_idx < logits.shape[0]: - self.mask[mask_idx] = True - self.force_token_ids[mask_idx] = ( + if ( + mask_idx < self._mask_capacity + and mask_idx < logits.shape[0] + ): + active_indices_cpu.append(mask_idx) + force_tokens_cpu.append( self.think_end_token_ids[end_count] ) if predict_bonus_token: @@ -514,15 +510,18 @@ def _apply_forcing_to_logits( else: state["bonus_token_forced"] = True - has_active_thinking = any( - state.get("in_end", False) for state in self._state.values() - ) - - if has_active_thinking: - active_indices = self.mask.nonzero(as_tuple=False).view(-1) - - if len(active_indices) > 0: - force_tokens = self.force_token_ids[active_indices] - logits[active_indices, force_tokens] = 1e9 + if active_indices_cpu: + device = logits.device + active_indices = async_tensor_h2d( + active_indices_cpu, dtype=torch.long, device=device + ) + force_tokens = async_tensor_h2d( + force_tokens_cpu, dtype=torch.long, device=device + ) + # Build the fill tensor on-device via `torch.full` (scalar-fill + # kernel, no H2D) so `index_put_` has a tensor value and won't + # need to upload a CPU scalar. + fill = logits.new_full((len(active_indices_cpu),), 1e9) + logits.index_put_((active_indices, force_tokens), fill) return logits diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index eb24a9c933e2..5640db4848c7 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -18,6 +18,7 @@ VllmConfig, ) from vllm.forward_context import set_forward_context +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -569,8 +570,10 @@ def update_ngram_gpu_tensors_incremental( reorder_dst.append(curr_idx) if reorder_src: - src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device) - dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device) + # Pinned CPU + non_blocking H2D avoids the synchronous copy that + # `torch.tensor(list, device=cuda)` would otherwise force. 
+ src_tensor = async_tensor_h2d(reorder_src, dtype=torch.long, device=device) + dst_tensor = async_tensor_h2d(reorder_dst, dtype=torch.long, device=device) temp_token_ids = token_ids_gpu_tensor[src_tensor].clone() temp_num_tokens = num_tokens_no_spec_gpu[src_tensor].clone() diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 64f4e59031e6..d1b1d9e7dd56 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -7,6 +7,7 @@ from vllm.config import ParallelConfig from vllm.distributed.parallel_state import get_dp_group from vllm.logger import init_logger +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.v1.worker.ubatch_utils import ( check_ubatch_thresholds, is_last_ubatch_empty, @@ -43,11 +44,13 @@ def _run_ar( dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank device, group = _get_device_and_group(parallel_config) - tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32) - tensor[0][dp_rank] = orig_num_tokens_per_ubatch - tensor[1][dp_rank] = padded_num_tokens_per_ubatch - tensor[2][dp_rank] = 1 if should_ubatch else 0 - tensor[3][dp_rank] = cudagraph_mode + # Populate this rank's contribution on CPU to reduce GPU syncs. + tensor_cpu = torch.zeros(4, dp_size, dtype=torch.int32) + tensor_cpu[0][dp_rank] = orig_num_tokens_per_ubatch + tensor_cpu[1][dp_rank] = padded_num_tokens_per_ubatch + tensor_cpu[2][dp_rank] = 1 if should_ubatch else 0 + tensor_cpu[3][dp_rank] = cudagraph_mode + tensor = tensor_cpu.to(device, non_blocking=True) dist.all_reduce(tensor, group=group) return tensor @@ -134,27 +137,29 @@ def _synchronize_dp_ranks( parallel_config=parallel_config, ) - # Synchronize cudagraph_mode across ranks first (take min). - # This is needed before DP padding decision since we use the synced - # cudagraph mode to determine whether DP padding is needed. - synced_cudagraph_mode = _post_process_cudagraph_mode(tensor) - - # Check conditions for microbatching - should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches) - - # DP padding is needed when cudagraph is enabled (synced across ranks) - # or when ubatching/DBO is active (ubatching requires uniform batch - # sizes across DP ranks currently). - # Use the synced runtime cudagraph mode rather than the compilation config - # so we can avoid padding when cudagraph is not enabled for this step. - should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch - - # Pad all DP ranks up to the maximum token count across ranks if - # should_dp_pad is True - num_tokens_after_padding = _post_process_dp_padding( - tensor, - should_dp_pad, - ) + # The post-all-reduce reads (`.item()` / `.cpu()`) are inherently + # GPU->CPU syncs: the values drive Python-level control flow for + # ubatching, DP padding, and cudagraph-mode selection. + with gpu_sync_allowed(): + # Synchronize cudagraph_mode across ranks first (take min). + # This is needed before DP padding decision since we use the synced + # cudagraph mode to determine whether DP padding is needed. + synced_cudagraph_mode = _post_process_cudagraph_mode(tensor) + + # Check conditions for microbatching + should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches) + + # DP padding is needed when cudagraph is enabled (synced across ranks) + # or when ubatching/DBO is active (ubatching requires uniform batch + # sizes across DP ranks currently). 
+ # Use the synced runtime cudagraph mode rather than the compilation + # config so we can avoid padding when cudagraph is not enabled for + # this step. + should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch + + # Pad all DP ranks up to the maximum token count across ranks if + # should_dp_pad is True + num_tokens_after_padding = _post_process_dp_padding(tensor, should_dp_pad) return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode diff --git a/vllm/v1/worker/gpu/buffer_utils.py b/vllm/v1/worker/gpu/buffer_utils.py index a653c262556c..5963790a7792 100644 --- a/vllm/v1/worker/gpu/buffer_utils.py +++ b/vllm/v1/worker/gpu/buffer_utils.py @@ -167,7 +167,7 @@ def apply_write(self) -> None: # Special handling for write_contents write_contents = async_tensor_h2d( - self._staged_write_contents, self.dtype, self.device, pin_memory=True + self._staged_write_contents, self.dtype, self.device ) # Write diffs to the GPU buffer diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index dfe50cb135d0..2cc9226a2df7 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -43,6 +43,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -1271,7 +1272,9 @@ def sample_tokens( if self.use_async_scheduling: return async_output - return async_output.get_output() + + with gpu_sync_allowed(): + return async_output.get_output() def take_draft_token_ids(self) -> DraftTokenIds | None: return self.draft_tokens_handler.get_draft_tokens() diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py index 04adf9369233..68b7f1b89b2c 100644 --- a/vllm/v1/worker/gpu/sample/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -58,8 +58,7 @@ def apply_staged_writes(self) -> None: idx_mapping = async_tensor_h2d( self._new_penalties_reqs, dtype=torch.int32, - target_device=self.device, - pin_memory=True, + device=self.device, ) prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs] @@ -284,8 +283,13 @@ def bincount( output_bin_counts: torch.Tensor, max_prefill_len: int, ) -> None: - prompt_bin_mask[expanded_idx_mapping] = 0 - output_bin_counts[expanded_idx_mapping] = 0 + # Use index_fill_ (which needs int64 indices) instead of + # `tensor[idx] = 0`: advanced-indexing scalar assignment lowers to + # aten::index_put_ on CUDA, which forces a host sync even with a + # GPU-resident index tensor. index_fill_ has no such sync. 
+ idx_long = expanded_idx_mapping.long() + prompt_bin_mask.index_fill_(0, idx_long, 0) + output_bin_counts.index_fill_(0, idx_long, 0) num_tokens = expanded_idx_mapping.shape[0] BLOCK_SIZE = 1024 num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 53197a5c81a5..855fc337171c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -106,6 +106,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tracing import instrument from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.utils.gpu_sync_debug import gpu_sync_allowed from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.nvtx_pytorch_hooks import PytHooks @@ -987,16 +988,19 @@ def _init_model_kwargs(self): if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens[:num_reqs] + # Build ids on CPU using the CPU-resident upper bound for seq_lens; + # `torch.arange(seq_lens[i])` with a GPU scalar would force a sync. + seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs].tolist() token_type_ids = [] for i in range(num_reqs): - pos = token_type_id_requests.get(i, seq_lens[i]) - ids = (torch.arange(seq_lens[i]) >= pos).int() + seq_len_i = seq_lens_cpu[i] + pos = token_type_id_requests.get(i, seq_len_i) + ids = (torch.arange(seq_len_i) >= pos).int() token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device + device=self.device, non_blocking=True ) return model_kwargs @@ -1441,9 +1445,13 @@ def _update_states_after_model_execute( self.num_accepted_tokens.gpu[:num_reqs] = (output_token_ids != -1).sum(dim=1) if self.cache_config.mamba_cache_mode == "align": - for i, num_tokens in enumerate( - self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() - ): + # Align mode needs the Python values immediately to call + # `postprocess_mamba` below; unavoidable D2H. Opt-in cache mode. + from vllm.utils.gpu_sync_debug import gpu_sync_allowed + + with gpu_sync_allowed(): + accepted_np = self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() + for i, num_tokens in enumerate(accepted_np): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens mamba_utils.postprocess_mamba( scheduler_output, @@ -2677,10 +2685,9 @@ def _prepare_kv_sharing_fast_prefill( # There might have leftover indices in logits_indices[num_logits:] # from previous iterations, whose values may be greater than the # batch size in the current iteration. To ensure indices are always - # valid, we fill the padded indices with the last index. - self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item() - ) + # valid, fill the padded indices with the last index. Broadcast the + # scalar GPU-side to avoid a D2H sync on `.item()`. + self.kv_sharing_fast_prefill_logits_indices[num_logits:] = logits_indices[-1] # Dispatch for the decoder portion of the model. _, batch_desc = self.cudagraph_dispatcher.dispatch( num_logits, invalid_modes={CUDAGraphMode.FULL} @@ -3272,13 +3279,15 @@ def _preprocess( # If a batch only has token ids, then including the embedding layer # in the CUDA graph will be more performant (like in the else case # below). 
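# Illustration only, not part of this patch: the `.item()`-free padding fill
# from `_prepare_kv_sharing_fast_prefill` above, with placeholder sizes.
import torch

logits_indices = torch.arange(10, device="cuda")
padded_indices = torch.zeros(32, dtype=torch.int64, device="cuda")
num_logits = 10

# Sync-forcing form: `.item()` copies the scalar to the host first:
#   padded_indices[num_logits:].fill_(logits_indices[-1].item())
# Sync-free form: assign the 0-d GPU tensor directly; broadcasting performs
# the fill entirely on-device.
padded_indices[num_logits:] = logits_indices[-1]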
- token_ids_idx = ( - self.is_token_ids.gpu[:num_scheduled_tokens] - .nonzero(as_tuple=False) - .squeeze(1) - ) + # Find token-id positions on CPU (is_token_ids.np is always kept + # in sync in _prepare_inputs) and upload indices non-blocking, so + # this step does not force a GPU sync from .nonzero(). + is_token_ids = self.is_token_ids.np[:num_scheduled_tokens] + token_ids_idx_np = np.nonzero(is_token_ids)[0] # Some tokens ids may need to become embeds - if token_ids_idx.numel() > 0: + if token_ids_idx_np.size > 0: + token_ids_idx = torch.from_numpy(token_ids_idx_np) + token_ids_idx = token_ids_idx.to(self.device, non_blocking=True) token_ids = self.input_ids.gpu[token_ids_idx] tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds @@ -3400,25 +3409,28 @@ def _bookkeeping_sync( invalid_req_indices = [] logprobs_lists = None if not self.use_async_scheduling: - # Get the valid generated tokens. - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = self._to_list(sampled_token_ids) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() - - if logprobs_tensors is not None: - logprobs_lists = logprobs_tensors.tolists() - else: - # Includes spec decode tokens. - valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - discard_sampled_tokens_req_indices, - logprobs_tensors=logprobs_tensors, - ) + with gpu_sync_allowed(): + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() + + if logprobs_tensors is not None: + logprobs_lists = logprobs_tensors.tolists() + else: + # Includes spec decode tokens. + valid_sampled_token_ids, logprobs_lists = ( + RejectionSampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + discard_sampled_tokens_req_indices, + logprobs_tensors=logprobs_tensors, + ) + ) else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 36dd5ed7fc65..c92f4d36fcd4 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -45,6 +45,7 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.tracing import instrument +from vllm.utils.gpu_sync_debug import with_gpu_sync_check from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.torch_utils import set_random_seed @@ -711,6 +712,16 @@ def compile_or_warm_up_model(self) -> CompilationTimes: # the model initialization and profiling. set_random_seed(self.model_config.seed) + from vllm.compilation.compiler_interface import trigger_inductor_lazy_init + + trigger_inductor_lazy_init(self.device) + + # Warmup / first-compile is done — activate the `VLLM_GPU_SYNC_CHECK` + # gate so subsequent `execute_model` / `sample_tokens` calls enforce it. 
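# Illustration only, not part of this patch: locating token-id positions on
# the CPU copy and uploading just the indices, as in `_preprocess` above.
# The names stand in for the patch's is_token_ids.np / input_ids.gpu.
import numpy as np
import torch

device = torch.device("cuda")
is_token_ids_np = np.array([True, True, False, True, False])
input_ids_gpu = torch.arange(5, device=device)

# A GPU-side `.nonzero()` must sync to learn the result size; np.nonzero on
# the host-resident copy yields the same indices with no device round-trip.
token_ids_idx_np = np.nonzero(is_token_ids_np)[0]
if token_ids_idx_np.size > 0:
    token_ids_idx = torch.from_numpy(token_ids_idx_np).to(device, non_blocking=True)
    token_ids = input_ids_gpu[token_ids_idx]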
+ from vllm.utils.gpu_sync_debug import enable_gpu_sync_check + + enable_gpu_sync_check() + return CompilationTimes( language_model=self.compilation_config.compilation_time, encoder=self.compilation_config.encoder_compilation_time, @@ -764,12 +775,14 @@ def annotate_profile(self, scheduler_output): return self.profiler.annotate_context_manager(annotation) @torch.inference_mode() + @with_gpu_sync_check def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput: return self.model_runner.sample_tokens(grammar_output) @torch.inference_mode() + @with_gpu_sync_check def execute_model( self, scheduler_output: "SchedulerOutput" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
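# Illustration only, not part of this patch: one plausible shape for the
# gpu_sync_debug helpers used above (enable_gpu_sync_check, gpu_sync_allowed,
# with_gpu_sync_check). This sketch assumes they map VLLM_GPU_SYNC_CHECK onto
# torch.cuda.set_sync_debug_mode; the real vllm.utils.gpu_sync_debug module
# is not shown in this diff and may differ.
import contextlib
import functools
import os

import torch

_MODE = os.environ.get("VLLM_GPU_SYNC_CHECK", "")  # "warn" | "error" | unset

def enable_gpu_sync_check_sketch() -> None:
    # Armed only after warmup so one-time compile/capture syncs stay legal.
    if _MODE in ("warn", "error"):
        torch.cuda.set_sync_debug_mode(_MODE)

@contextlib.contextmanager
def gpu_sync_allowed_sketch():
    # Temporarily relax the check for reads that must sync, e.g. the
    # post-all-reduce `.item()` / `.cpu()` calls earlier in this patch.
    if _MODE in ("warn", "error"):
        torch.cuda.set_sync_debug_mode("default")
    try:
        yield
    finally:
        if _MODE in ("warn", "error"):
            torch.cuda.set_sync_debug_mode(_MODE)

def with_gpu_sync_check_sketch(fn):
    # Decorator form for the execute_model / sample_tokens entry points:
    # enforce the mode for the duration of the call, then restore.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if _MODE not in ("warn", "error"):
            return fn(*args, **kwargs)
        torch.cuda.set_sync_debug_mode(_MODE)
        try:
            return fn(*args, **kwargs)
        finally:
            torch.cuda.set_sync_debug_mode("default")
    return wrapper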