vllm/v1/worker/gpu_model_runner.py (25 changes: 23 additions & 2 deletions)
@@ -316,6 +316,12 @@ def __init__(
         # Cached outputs.
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
+        self.transfer_event = torch.cuda.Event()
+        self.sampled_token_ids_pinned_cpu = torch.empty(
+            (self.max_model_len, 1),
+            dtype=torch.int64,
+            device="cpu",
+            pin_memory=True)
 
     def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
         return CpuGpuBuffer(*args,
@@ -1691,7 +1697,7 @@ def execute_model(
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
+            valid_sampled_token_ids = self._to_list(sampled_token_ids)
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -2219,7 +2225,7 @@ def _dummy_run(
             - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
             - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
                 needed.
-            force_attention: If True, always create attention metadata. Used to
+            force_attention: If True, always create attention metadata. Used to
                 warm up attention backend when mode is NONE.
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
@@ -3233,3 +3239,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     mamba_type=mamba_module.mamba_type)
 
         return kv_cache_spec
+
+    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+        # This is a short-term mitigation for the issue described in
+        # https://github.com/vllm-project/vllm/issues/22754.
+        # `tolist` triggers a device-wide CUDA stream sync, which
+        # blocks copy ops issued on other CUDA streams. Syncing on a
+        # CUDA event instead avoids that. Since this is on the critical
+        # path of every model forward pass, the device-wide sync has
+        # caused perf issues in disaggregated setups.
+        pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
+        pinned.copy_(sampled_token_ids, non_blocking=True)
+        self.transfer_event.record()
+        self.transfer_event.synchronize()
+        return pinned.tolist()
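
For readers skimming the diff: the new `_to_list` pairs the pinned buffer and CUDA event allocated in `__init__` with a non-blocking device-to-host copy, so the CPU waits only for that one copy rather than the whole device. Below is a minimal, self-contained sketch of the same pattern, not the PR's code itself; `MAX_ROWS` and `gpu_ids_to_list` are illustrative names.

```python
import torch

MAX_ROWS = 1024  # stand-in for self.max_model_len in the PR

def gpu_ids_to_list(gpu_ids: torch.Tensor,
                    pinned_buf: torch.Tensor,
                    copy_done: torch.cuda.Event) -> list[list[int]]:
    # Enqueue an async device-to-host copy into a slice of the
    # preallocated pinned buffer on the current stream.
    host_view = pinned_buf[:gpu_ids.shape[0]]
    host_view.copy_(gpu_ids, non_blocking=True)
    # Block the CPU only until this copy finishes. Calling .tolist()
    # on the CUDA tensor directly would instead synchronize the whole
    # device, stalling copies on other streams.
    copy_done.record()
    copy_done.synchronize()
    return host_view.tolist()

if torch.cuda.is_available():
    pinned_buf = torch.empty((MAX_ROWS, 1), dtype=torch.int64,
                             device="cpu", pin_memory=True)
    copy_done = torch.cuda.Event()
    ids = torch.randint(0, 32000, (8, 1), dtype=torch.int64, device="cuda")
    print(gpu_ids_to_list(ids, pinned_buf, copy_done))
```

The design point is that `Event.synchronize()` waits only for work recorded before the event on the copy's stream, so concurrent transfers on other streams, such as KV-cache movement in a disaggregated prefill/decode setup, keep running.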