def _make_profiler_config(profile_dir: str) -> dict:
    """Build a ``profiler_config`` dict for torch profiling.

    Args:
        profile_dir: Directory passed through as ``torch_profiler_dir``.

    Returns:
        A dict selecting the ``"torch"`` profiler, pointing it at
        ``profile_dir`` and enabling stack capture.
    """
    return {
        "profiler": "torch",
        "torch_profiler_dir": profile_dir,
        "torch_profiler_with_stack": True,
    }


def make_random_prompts(
    num_prompts: int, prompt_len: int, vocab_size: int, seed: int = 42
) -> list[list[int]]:
    """Generate reproducible lists of random token IDs.

    Args:
        num_prompts: Number of prompts to generate.
        prompt_len: Number of token IDs in each prompt.
        vocab_size: Exclusive upper bound for the sampled token IDs.
        seed: Seed for the private random generator.

    Returns:
        ``num_prompts`` lists, each holding ``prompt_len`` ints drawn from
        ``[0, vocab_size)``.
    """
    # Use a private generator instead of torch.manual_seed(): the prompts stay
    # reproducible, but the process-wide RNG state is left untouched for
    # whatever runs after this helper.
    rng = torch.Generator()
    rng.manual_seed(seed)
    return [
        torch.randint(0, vocab_size, (prompt_len,), generator=rng).tolist()
        for _ in range(num_prompts)
    ]
+def cleanup_hidden_states(path: str) -> None: + lock_path = path + ".lock" + if os.path.exists(lock_path): + os.remove(lock_path) + if os.path.exists(path): + os.remove(path) + + +def consume_hidden_states(path: str) -> float: + """Load hidden states from disk and compute per-position mean. + + Returns a single float: the grand mean of all hidden state values. + This forces the benchmark to actually read and reduce the data. + + Uses :func:`load_hidden_states` which acquires a shared flock, + blocking (without polling) until the async writer releases its + exclusive lock. + """ + obj = example_hidden_states_connector.load_hidden_states(path) + hs = obj["hidden_states"] + total = hs.mean().item() + + cleanup_hidden_states(path) + + return total + + +def run_baseline( + model: str, + prompts: list[list[int]], + extra_args: dict, + profile_dir: str | None = None, +) -> dict: + """Baseline: bulk inference, no hidden state extraction.""" + if profile_dir: + extra_args = { + **extra_args, + "profiler_config": _make_profiler_config(profile_dir), + } + llm = LLM(model=model, enable_prefix_caching=False, **extra_args) + sampling_params = SamplingParams(max_tokens=1) + prompt_inputs = [{"prompt_token_ids": p} for p in prompts] + + # Warmup + llm.generate(prompt_inputs[:4], sampling_params, use_tqdm=False) + + if profile_dir: + llm.start_profile() + + t0 = time.perf_counter() + outputs = llm.generate(prompt_inputs, sampling_params, use_tqdm=True) + elapsed = time.perf_counter() - t0 + + if profile_dir: + llm.stop_profile() + + total_prompt_tokens = sum(len(o.prompt_token_ids) for o in outputs) + num_prompts = len(outputs) + + del llm + torch.accelerator.empty_cache() + + return { + "mode": "baseline", + "elapsed_s": elapsed, + "num_prompts": num_prompts, + "total_prompt_tokens": total_prompt_tokens, + "tokens_per_s": total_prompt_tokens / elapsed, + "prompts_per_s": num_prompts / elapsed, + } + + +# ---- Async extraction benchmark ---- + + +async def _client_loop( + engine: 
AsyncLLM, + prompt_queue: asyncio.Queue, + consume_pool: ThreadPoolExecutor, + results: list[dict], + client_id: int, +): + """A single async client: pulls prompts, submits to engine, consumes + hidden states as soon as each request finishes.""" + loop = asyncio.get_event_loop() + while True: + item = await prompt_queue.get() + if item is None: + prompt_queue.task_done() + break + idx, token_ids = item + + request_id = f"req-{idx}" + sampling_params = SamplingParams( + max_tokens=1, + output_kind=RequestOutputKind.FINAL_ONLY, + ) + + final_output = None + async for output in engine.generate( + request_id=request_id, + prompt={"prompt_token_ids": token_ids}, + sampling_params=sampling_params, + ): + if output.finished: + final_output = output + + # Consume hidden states on a thread (disk I/O) + path = final_output.kv_transfer_params["hidden_states_path"] + mean_val = await loop.run_in_executor(consume_pool, consume_hidden_states, path) + num_tokens = len(final_output.prompt_token_ids) + + results.append( + { + "request_id": request_id, + "num_prompt_tokens": num_tokens, + "mean_hidden_value": mean_val, + } + ) + prompt_queue.task_done() + + +async def _run_extraction_async( + model: str, + prompts: list[list[int]], + num_clients: int, + layers: list[int], + tmpdir: str, + extra_args: dict, + profile_dir: str | None = None, +) -> dict: + if profile_dir: + extra_args = { + **extra_args, + "profiler_config": _make_profiler_config(profile_dir), + } + engine_args = AsyncEngineArgs( + model=model, + enable_prefix_caching=False, + enable_chunked_prefill=False, + speculative_config={ + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": { + "eagle_aux_hidden_state_layer_ids": layers, + }, + }, + }, + kv_transfer_config=KVTransferConfig( + kv_connector="ExampleHiddenStatesConnector", + kv_role="kv_producer", + kv_connector_extra_config={ + "shared_storage_path": tmpdir, + }, + ), + **extra_args, + ) + engine = 
AsyncLLM.from_engine_args(engine_args) + + try: + # Warmup: run a few prompts sequentially, cleaning up generated files + for i in range(min(4, len(prompts))): + sp = SamplingParams(max_tokens=1, output_kind=RequestOutputKind.FINAL_ONLY) + final_output = None + async for output in engine.generate( + request_id=f"warmup-{i}", + prompt={"prompt_token_ids": prompts[i]}, + sampling_params=sp, + ): + if output.finished: + final_output = output + if final_output and final_output.kv_transfer_params: + path = final_output.kv_transfer_params.get("hidden_states_path") + if path: + cleanup_hidden_states(path) + + if profile_dir: + await engine.start_profile() + + # Fill prompt queue + prompt_queue: asyncio.Queue = asyncio.Queue() + for idx, token_ids in enumerate(prompts): + prompt_queue.put_nowait((idx, token_ids)) + # Sentinel per client + for _ in range(num_clients): + prompt_queue.put_nowait(None) + + results: list[dict] = [] + consume_pool = ThreadPoolExecutor(max_workers=num_clients) + + t0 = time.perf_counter() + tasks = [ + asyncio.create_task( + _client_loop(engine, prompt_queue, consume_pool, results, i) + ) + for i in range(num_clients) + ] + await asyncio.gather(*tasks) + elapsed = time.perf_counter() - t0 + + consume_pool.shutdown(wait=True) + + if profile_dir: + await engine.stop_profile() + + total_prompt_tokens = sum(r["num_prompt_tokens"] for r in results) + num_prompts = len(results) + mean_hidden = sum(r["mean_hidden_value"] for r in results) / max( + len(results), 1 + ) + + return { + "mode": "extract", + "elapsed_s": elapsed, + "num_prompts": num_prompts, + "total_prompt_tokens": total_prompt_tokens, + "tokens_per_s": total_prompt_tokens / elapsed, + "prompts_per_s": num_prompts / elapsed, + "mean_hidden_value": mean_hidden, + } + finally: + engine.shutdown() + + +def run_extraction( + model: str, + prompts: list[list[int]], + num_clients: int, + layers: list[int], + extra_args: dict, + profile_dir: str | None = None, +) -> dict: + return asyncio.run( + 
_run_extraction_async( + model, + prompts, + num_clients, + layers, + "/dev/shm", + extra_args, + profile_dir=profile_dir, + ) + ) + + +def print_results(results: dict): + mode = results["mode"] + print(f"\n{'=' * 60}") + print(f" {mode.upper()} RESULTS") + print(f"{'=' * 60}") + print(f" Prompts: {results['num_prompts']}") + print(f" Total prompt tokens: {results['total_prompt_tokens']:,}") + print(f" Wall time: {results['elapsed_s']:.2f}s") + print(f" Tokens/s: {results['tokens_per_s']:,.0f}") + print(f" Prompts/s: {results['prompts_per_s']:.2f}") + if mode == "extract": + print(f" Mean hidden value: {results['mean_hidden_value']:.6f}") + print(f"{'=' * 60}\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark hidden state extraction throughput" + ) + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--num-prompts", type=int, default=64) + parser.add_argument("--num-clients", type=int, default=8) + parser.add_argument("--prompt-len", type=int, default=8192) + parser.add_argument("--layers", type=int, nargs="+", default=[1, 2, 3, 4]) + parser.add_argument("--skip-baseline", action="store_true") + parser.add_argument("--skip-extract", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--max-num-batched-tokens", type=int, default=8192) + parser.add_argument("--max-cudagraph-capture-size", type=int, default=None) + parser.add_argument("--max-model-len", type=int, default=None) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--load-format", type=str, default=None) + parser.add_argument( + "--profile", + action="store_true", + help="Enable torch profiler for both baseline and extraction runs.", + ) + parser.add_argument( + "--torch-profiler-dir", + type=str, + default="./vllm_profile", + help="Directory to save torch profiler traces (default: ./vllm_profile).", + ) + parser.add_argument( + 
"--enable-flashinfer-autotune", + action="store_true", + default=False, + help="Enable FlashInfer autotuning (can be slow).", + ) + args = parser.parse_args() + + extra_args = { + "gpu_memory_utilization": args.gpu_memory_utilization, + "max_num_batched_tokens": args.max_num_batched_tokens, + } + if args.max_model_len is not None: + extra_args["max_model_len"] = args.max_model_len + if args.enforce_eager: + extra_args["enforce_eager"] = True + if args.load_format is not None: + extra_args["load_format"] = args.load_format + if args.max_cudagraph_capture_size is not None: + extra_args["max_cudagraph_capture_size"] = args.max_cudagraph_capture_size + extra_args["enable_flashinfer_autotune"] = args.enable_flashinfer_autotune + + # Get vocab size from HF config without loading the full model + hf_config = AutoConfig.from_pretrained(args.model, trust_remote_code=True) + vocab_size = hf_config.vocab_size + prompts = make_random_prompts(args.num_prompts, args.prompt_len, vocab_size) + print( + f"Generated {args.num_prompts} prompts, " + f"{args.prompt_len} tokens each (vocab {vocab_size})" + ) + + profile_dir = args.torch_profiler_dir if args.profile else None + if profile_dir: + print(f"Torch profiler enabled, traces will be saved to {profile_dir}/") + + if not args.skip_baseline: + baseline_profile_dir = f"{profile_dir}/baseline" if profile_dir else None + baseline = run_baseline( + args.model, prompts, extra_args, profile_dir=baseline_profile_dir + ) + print_results(baseline) + + if not args.skip_extract: + extract_profile_dir = f"{profile_dir}/extract" if profile_dir else None + extract = run_extraction( + args.model, + prompts, + args.num_clients, + args.layers, + extra_args, + profile_dir=extract_profile_dir, + ) + print_results(extract) + + if not args.skip_baseline and not args.skip_extract: + slowdown = baseline["tokens_per_s"] / extract["tokens_per_s"] + print("Extraction slowdown factor: {:.2f}x".format(slowdown)) + + +if __name__ == "__main__": + main() diff 
def load_hidden_states(path: str) -> dict[str, torch.Tensor]:
    """Load hidden states written by ExampleHiddenStatesConnector.

    Blocks (without polling) until the async write is complete by
    acquiring a shared flock on the companion lock file. The kernel
    puts the caller to sleep until the writer releases its exclusive lock.

    Args:
        path: The file path returned in kv_transfer_params["hidden_states_path"].

    Returns:
        Dict with "hidden_states" and "token_ids" tensors.

    Raises:
        FileNotFoundError: If the companion ``path + ".lock"`` file (or the
            data file itself) does not exist.
    """
    lock_path = path + ".lock"
    with open(lock_path) as lock_file:
        # Sleeps until the writer releases LOCK_EX. Our shared lock is dropped
        # when the with-block closes lock_file.
        fcntl.flock(lock_file, fcntl.LOCK_SH)
        # torch.load unpickles the file. The payload is a plain dict of
        # tensors, so restrict unpickling to tensor data rather than allowing
        # arbitrary objects from a shared directory.
        return torch.load(path, map_location="cpu", weights_only=True)
# Worker-side methods of the hidden-states connector. The class header and
# the surrounding __init__ lie outside this hunk.


def _get_copy_stream(self) -> torch.cuda.Stream:
    """Lazily create the copy stream (CUDA must be initialized)."""
    if self._copy_stream is None:
        self._copy_stream = torch.cuda.Stream()
    return self._copy_stream


# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, *args, **kwargs: Any) -> None:
    """No-op: store-only connector, nothing to load."""
    pass


def wait_for_layer_load(self, layer_name: str) -> None:
    """No-op: store-only connector, nothing to load."""
    pass


def wait_for_save(self):
    """Submit pending async copies to the thread pool for disk write.

    For each pending write we acquire an exclusive flock on a
    companion ``.lock`` file **before** submitting to the thread pool.
    The thread worker releases the lock after the data file is fully
    written. Clients call :func:`load_hidden_states` which takes a
    shared flock — the kernel sleeps the client until the writer is
    done.

    The pending list is emptied even when a submission fails, so a later
    call never resubmits entries that were already handed to the pool.

    Raises:
        AssertionError: If a write for the same ``req_id`` is still tracked.
    """
    try:
        for tensors, event, filename, req_id in self._pending_copies:
            # Raised explicitly so the invariant also holds under ``python -O``,
            # where a bare ``assert`` would be stripped.
            if self._req_futures.get(req_id) is not None:
                raise AssertionError(
                    "Found another KV transfer request with same req_id!"
                )

            # Create/open the lock file and acquire an exclusive lock.
            # The lock is held by this fd; the thread worker will close
            # the fd after writing, which releases the lock.
            lock_path = filename + ".lock"
            lock_fd = os.open(lock_path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o644)
            try:
                fcntl.flock(lock_fd, fcntl.LOCK_EX)
                future = self._executor.submit(
                    self._write_tensors, tensors, event, filename, lock_fd
                )
            except BaseException:
                # Ownership of lock_fd never reached a worker. Close it here,
                # otherwise the exclusive lock stays held and every reader
                # waiting on LOCK_SH blocks forever.
                os.close(lock_fd)
                raise
            self._req_futures[req_id] = future
    finally:
        self._pending_copies.clear()


@staticmethod
def _write_tensors(
    tensors: dict[str, torch.Tensor],
    event: torch.cuda.Event,
    filename: str,
    lock_fd: int,
) -> None:
    """Thread worker: wait for async DtoH copy, write to disk, release lock.

    ``lock_fd`` is an open file descriptor on the companion ``.lock``
    file with ``LOCK_EX`` already held. Closing it releases the lock,
    which unblocks any client sleeping on ``LOCK_SH``.

    Args:
        tensors: Tensors to serialize with ``torch.save``.
        event: Event that is synchronized before the tensors are written.
        filename: Destination path of the data file.
        lock_fd: Locked descriptor that this function closes in all cases.
    """
    try:
        event.synchronize()
        torch.save(tensors, filename)
    finally:
        os.close(lock_fd)  # releases LOCK_EX
kv_layer (torch.Tensor): the paged KV buffer of the current @@ -206,15 +300,46 @@ def save_kv_layer( assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata) os.makedirs(self._storage_path, exist_ok=True) + + copy_stream = self._get_copy_stream() + + # Ensure the copy stream sees all prior writes on the default stream. + ready_event = torch.cuda.Event() + ready_event.record() + copy_stream.wait_event(ready_event) + for request in connector_metadata.requests: - hidden_states = extract_from_kv_cache( - kv_layer, request.slot_mapping, request.token_ids.shape[0] + with torch.cuda.stream(copy_stream): + # Move the CPU slot_mapping to GPU on the copy stream so the + # implicit H2D inside fancy indexing doesn't sync the default + # stream. + slot_mapping_gpu = request.slot_mapping.to( + device=kv_layer.device, non_blocking=True + ) + hidden_states_gpu = extract_from_kv_cache( + kv_layer, slot_mapping_gpu, request.token_ids.shape[0] + ) + # Async DtoH copy into pinned host memory. + pinned_hs = torch.empty_like( + hidden_states_gpu, device="cpu", pin_memory=True + ) + pinned_hs.copy_(hidden_states_gpu, non_blocking=True) + + # Record completion of this copy on the copy stream. + copy_done = torch.cuda.Event() + copy_done.record(copy_stream) + + # token_ids is already on CPU (created in ReqMeta.make_meta). 
+ assert not request.token_ids.is_cuda, ( + "Expected token_ids on CPU, got CUDA tensor" ) tensors = { - "hidden_states": hidden_states.detach().cpu(), - "token_ids": request.token_ids.detach().cpu(), + "hidden_states": pinned_hs, + "token_ids": request.token_ids.clone(), } - safetensors.torch.save_file(tensors, request.filename) + self._pending_copies.append( + (tensors, copy_done, request.filename, request.req_id) + ) # ============================== # Scheduler-side methods @@ -264,7 +389,7 @@ def build_connector_meta( meta = ExampleHiddenStatesConnectorMetadata() for new_req in scheduler_output.scheduled_new_reqs: token_ids = new_req.prompt_token_ids or [] - filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors") + filename = os.path.join(self._storage_path, f"{new_req.req_id}.pt") meta.add_request( new_req.req_id, filename=filename, @@ -329,7 +454,31 @@ def request_finished( _ = self._active_requests.pop(req_id, None) _ = self._req_blocks.pop(req_id, None) - return False, {"hidden_states_path": req_filename} + return True, {"hidden_states_path": req_filename} + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """Poll async write completion for requests that finished generating. + + The scheduler passes finished_req_ids to tell the worker which + requests are done generating. We accumulate these across calls + and return a request as "finished sending" once its disk write + Future is complete (or if it never had a pending write). 
+ """ + self._accumulated_finished_req_ids.update(finished_req_ids) + + done_sending: set[str] = set() + for req_id in list(self._accumulated_finished_req_ids): + future = self._req_futures.get(req_id) + if future is None or future.done(): + if future is not None: + future.result() # propagate write exceptions + del self._req_futures[req_id] + done_sending.add(req_id) + self._accumulated_finished_req_ids.discard(req_id) + + return done_sending or None, None @classmethod def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index dd4e47d45a6d..9ae0007ffa9d 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING +import numpy as np import torch import torch.nn as nn @@ -12,8 +13,10 @@ from vllm.forward_context import set_forward_context from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -49,6 +52,14 @@ def __init__(self, vllm_config: VllmConfig, device): vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size ) + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None) if not layer_ids: @@ -300,19 
+311,19 @@ def prepare_next_token_ids_padded( (if valid and not discarded) or a backup token from the request state. """ num_reqs = gpu_input_batch.num_reqs - device = sampled_token_ids.device # Compute backup tokens for discarded / invalid requests - backup_tokens_gpu = torch.tensor( + self.backup_next_token_ids.np[:num_reqs] = np.array( [ requests[gpu_input_batch.req_ids[i]].get_token_id( common_attn_metadata.seq_lens_cpu[i].item() ) for i in range(num_reqs) ], - dtype=torch.int32, - device=device, + dtype=np.int32, ) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + backup_tokens_gpu = self.backup_next_token_ids.gpu[:num_reqs] assert discard_request_mask.dtype == torch.bool