diff --git a/benchmarks/multi_turn_tq/BENCHMARK_REPORT.md b/benchmarks/multi_turn_tq/BENCHMARK_REPORT.md new file mode 100644 index 000000000000..c3aa0600420f --- /dev/null +++ b/benchmarks/multi_turn_tq/BENCHMARK_REPORT.md @@ -0,0 +1,126 @@ +# Multi-Turn KV Cache Compression Benchmark Report + +## Summary + +Comparing KV cache compression strategies on MiniMax-M2.7 with TP=2, simulating 192GB GPU memory. + +**TurboQuant 4-bit achieves 85.7% cache hit rate vs FP8's 27.5%**, resulting in: + +- **4.6× faster TTFT** than BF16 baseline (8.0s vs 37s) +- **2.4× faster TTFT** than FP8 (8.0s vs 19.5s) +- **1.3× faster total duration** than FP8 (492s vs 639s) + +## Configuration + +| Parameter | Value | +|-----------|-------| +| Model | MiniMax-M2.7 | +| TP Size | 2 | +| GPU Memory Util | 0.6 | +| Clients | 40 | +| Rounds | 8 | +| Common Prefix | 2,000 tokens | +| Per-Client Prefix | 32,000 tokens | +| Input/Round | 2,000 tokens | +| Output/Round | 200 tokens | + +## Results + +### Overall Metrics + +| Metric | BF16 | FP8 | TQ 4-bit | +|--------|------|-----|----------| +| TTFT Mean | 36,986 ms | 19,497 ms | **8,006 ms** | +| TTFT P90 | 63,394 ms | 43,452 ms | **10,580 ms** | +| Cache Hit Rate | 6.6% | 27.5% | **85.7%** | +| Throughput | 12,154 tok/s | 16,159 tok/s | **21,038 tok/s** | +| Total Duration | 849s | 639s | **492s** | + +### Per-Round Cache Hit Rate + +| Round | BF16 | FP8 | TQ 4-bit | +|-------|------|-----|----------| +| 0 | 5.8% | 5.8% | 5.8% | +| 1 | 10.0% | 73.5% | **93.7%** | +| 2 | 7.5% | 48.4% | **94.1%** | +| 3 | 10.7% | 31.3% | **94.4%** | +| 4 | 7.0% | 22.4% | **94.7%** | +| 5 | 4.5% | 18.5% | **95.0%** | +| 6 | 4.3% | 15.6% | **95.2%** | +| 7 | 4.1% | 13.2% | **95.4%** | + +### Per-Round TTFT (ms) + +| Round | BF16 | FP8 | TQ 4-bit | +|-------|------|-----|----------| +| 0 | 27,723 | 21,111 | 29,844 | +| 1 | 28,104 | 3,676 | **4,289** | +| 2 | 31,519 | 11,413 | **4,382** | +| 3 | 33,229 | 16,977 | **4,701** | +| 4 | 37,693 | 20,896 | **4,903** | +| 5 | 42,181 | 23,893 | **5,002** | +| 6 | 45,708 | 26,919 | **5,295** | +| 7 | 49,731 | 31,086 | **5,628** | + +## Reproduction Steps + +```bash +cd benchmarks/multi_turn_tq + +# BF16 Baseline +HIP_VISIBLE_DEVICES=4,5 \ +GPU_MEMORY_UTIL=0.6 \ +MAX_MODEL_LEN=80000 \ +OUTPUT_TOKENS=200 \ +SUB_QUESTION_TOKENS=2000 \ +ATTENTION_BACKEND=ROCM_AITER_FA \ +./run_benchmark.sh \ + --kv-cache-dtype auto \ + --tag fix_baseline \ + --num-clients 40 \ + --num-rounds 8 \ + --common-prefix 2000 \ + --prefix-tokens 32000 \ + --port 6789 + +# FP8 +HIP_VISIBLE_DEVICES=2,3 \ +GPU_MEMORY_UTIL=0.6 \ +MAX_MODEL_LEN=80000 \ +OUTPUT_TOKENS=200 \ +SUB_QUESTION_TOKENS=2000 \ +ATTENTION_BACKEND=ROCM_AITER_FA \ +./run_benchmark.sh \ + --kv-cache-dtype fp8_e4m3 \ + --tag fix_fp8 \ + --num-clients 40 \ + --num-rounds 8 \ + --common-prefix 2000 \ + --prefix-tokens 32000 \ + --port 6791 + +# TurboQuant 4-bit +HIP_VISIBLE_DEVICES=6,7 \ +GPU_MEMORY_UTIL=0.6 \ +MAX_MODEL_LEN=80000 \ +OUTPUT_TOKENS=200 \ +SUB_QUESTION_TOKENS=2000 \ +VLLM_TQ_DECODE_V3=1 \ +./run_benchmark.sh \ + --kv-cache-dtype turboquant_4bit_nc \ + --tag fix3_tq4bit \ + --num-clients 40 \ + --num-rounds 8 \ + --common-prefix 2000 \ + --prefix-tokens 32000 \ + --port 6790 + +# Compare results +python compare_results.py results/multiturn/results_fix_baseline_*.json results/multiturn/results_fix_fp8_*.json results/multiturn/results_fix3_tq4bit_*.json +``` + +## Result Files + +- `results/multiturn/results_fix_baseline_20260429_022339.json` +- `results/multiturn/results_fix_fp8_20260429_022010.json` +- `results/multiturn/results_fix3_tq4bit_20260429_202437.json` diff --git a/benchmarks/multi_turn_tq/SKILL_MULTITURN_BENCHMARK.md b/benchmarks/multi_turn_tq/SKILL_MULTITURN_BENCHMARK.md new file mode 100644 index 000000000000..d5b1ad10a5b0 --- /dev/null +++ b/benchmarks/multi_turn_tq/SKILL_MULTITURN_BENCHMARK.md @@ -0,0 +1,232 @@ +--- +name: multiturn-benchmark +description: > + Benchmark KV cache compression strategies (BF16 vs FP8 vs TurboQuant 4-bit) on vLLM + with multi-turn workloads on MiniMax-M2.7 / MI300X/MI355x. Finds scenarios where compression + outperforms baseline by creating memory pressure. Key result: TurboQuant achieves 85% + cache hit vs FP8's 27% under memory pressure, 2.6× faster TTFT. + Usage: /multiturn-benchmark [baseline|fp8|tq4bit] [--num-clients N] [--num-rounds N] +allowed-tools: Bash, Read, Grep, Glob +--- + +# Multi-Turn KV Cache Compression Benchmarking + +**Tags**: vllm, kv-cache, compression, turboquant, fp8, multi-turn, benchmark, prefix-caching +**Model**: MiniMax-M2.7 +**Hardware**: MI300X/MI355X (ROCm) + +## Overview + +This skill covers benchmarking KV cache compression strategies (BF16, FP8, TurboQuant 4-bit) on vLLM with multi-turn workloads. The goal is to find scenarios where compression techniques outperform baseline by creating memory pressure. + +## Key Concepts + +### When KV Cache Compression Shines + +- Compression benefits appear when **memory is the bottleneck** +- Need: `total_tokens > GPU_capacity / compression_ratio` +- Under light load, baseline wins (no memory pressure, compression has overhead) +- Under heavy load, compression wins (avoids cache eviction) + +### Cache Hit Rate + +- `cache_hit_rate = cached_tokens / prompt_tokens` +- High rate (90%+) = KV cache reused, fast TTFT +- Low rate (<20%) = cache evicted, must recompute, slow TTFT +- Per-round rate should increase in later rounds (history cached) + +### Memory Capacity Calculation + +``` +KV cache per token = 2 (K+V) × num_kv_heads × head_dim × dtype_bytes × num_layers + +For MiniMax-M2.7 with TP=2: +- Layers: 62, KV heads: 8 (4 per GPU), Head dim: 128 +- BF16: 127 KB/token +- FP8 (2×): 63.5 KB/token +- 4-bit (4×): 31.75 KB/token + +Capacity at 0.6 util on 288GB GPU (simulating 192GB): +- BF16: ~756k tokens +- FP8: ~1.5M tokens +- 4-bit: ~3M tokens +``` + +## Directory Structure + +``` +benchmarks/multi_turn_tq/ +├── run_benchmark.sh # Main wrapper script +├── bench_multiturn_enhanced.py # Python benchmark with per-round metrics +├── compare_results.py # Compare multiple result files +├── results/multiturn/ # JSON result files +├── logs/ # Server and benchmark logs +├── BENCHMARK_REPORT.md # Latest results report +└── SKILL_MULTITURN_BENCHMARK.md # This file +``` + +## Key Scripts + +### run_benchmark.sh + +Launches vLLM server and runs benchmark. Key environment variables: + +| Variable | Description | Default | +|----------|-------------|---------| +| `GPU_MEMORY_UTIL` | Fraction of GPU memory to use | 0.9 | +| `MAX_MODEL_LEN` | Max context length | 8192 | +| `OUTPUT_TOKENS` | Max output tokens per round | 100 | +| `SUB_QUESTION_TOKENS` | Input tokens per follow-up round | 200 | +| `ATTENTION_BACKEND` | Attention backend (empty for auto) | - | +| `KV_SKIP_LAYERS` | Layers to skip for KV quantization | - | +| `VLLM_TQ_DECODE_V3` | Enable TurboQuant decode v3 | - | + +Command line args: `--kv-cache-dtype`, `--tag`, `--num-clients`, `--num-rounds`, `--common-prefix`, `--prefix-tokens`, `--port`, `--skip-server` + +### bench_multiturn_enhanced.py + +Python benchmark that: + +- Runs multi-turn conversations with round barrier mode +- Captures actual `cached_tokens` from server (requires `--enable-prompt-tokens-details`) +- Records per-round TTFT and cache hit rate +- Uses actual model responses in history (critical for prefix caching!) + +## Important Fixes Applied + +### 1. History Must Use Actual Response + +**Bug**: Original code used placeholder `"[Response for round N]"` instead of actual model output. +**Impact**: Prefix caching fails because history doesn't match cached KV. +**Fix**: Store `result.generated_text` and use it in history. + +### 2. TurboQuant Server Flags + +```bash +VLLM_TQ_DECODE_V3=1 # Enable decode v3 +KV_SKIP_LAYERS="" # Pass --kv-cache-dtype-skip-layers "" +# Do NOT set ATTENTION_BACKEND - let TQ auto-select +``` + +### 3. TurboQuant Memory Overhead + +Known bug: TQ uses ~22GB more than expected. Workaround: increase `GPU_MEMORY_UTIL` by ~0.08 for TQ runs. + +## Parameter Tuning Guide + +### To Create Memory Pressure + +Increase total tokens until baseline starts evicting: + +``` +total_tokens = num_clients × tokens_per_client +tokens_per_client = common_prefix + unique_prefix + rounds × (output + input) +``` + +### To See Gradual Degradation + +Find settings where: + +- Round 0-2: All fit +- Round 3-5: FP8 starts evicting, TQ still fits +- Round 6+: FP8 heavily evicting, TQ starts evicting + +### Example Configurations + +**Light load (no pressure)**: + +- 16 clients, 8k prefix, 5 rounds → All methods similar + +**Medium load (BF16 evicts)**: + +- 40 clients, 16k prefix, 10 rounds → BF16 degrades, FP8/TQ OK + +**Heavy load (FP8 evicts)**: + +- 40 clients, 32k prefix, 8 rounds, 2k input/round → FP8 degrades, TQ wins + +## Reproduction Commands + +### Quick 3-Way Comparison + +```bash +cd benchmarks/multi_turn_tq + +# Run all 3 in parallel on different GPU pairs +HIP_VISIBLE_DEVICES=4,5 GPU_MEMORY_UTIL=0.6 MAX_MODEL_LEN=80000 \ +OUTPUT_TOKENS=200 SUB_QUESTION_TOKENS=2000 ATTENTION_BACKEND=ROCM_AITER_FA \ +./run_benchmark.sh --kv-cache-dtype auto --tag test_baseline \ +--num-clients 40 --num-rounds 8 --common-prefix 2000 --prefix-tokens 32000 --port 6789 & + +HIP_VISIBLE_DEVICES=2,3 GPU_MEMORY_UTIL=0.6 MAX_MODEL_LEN=80000 \ +OUTPUT_TOKENS=200 SUB_QUESTION_TOKENS=2000 ATTENTION_BACKEND=ROCM_AITER_FA \ +./run_benchmark.sh --kv-cache-dtype fp8_e4m3 --tag test_fp8 \ +--num-clients 40 --num-rounds 8 --common-prefix 2000 --prefix-tokens 32000 --port 6791 & + +HIP_VISIBLE_DEVICES=6,7 GPU_MEMORY_UTIL=0.68 MAX_MODEL_LEN=80000 \ +OUTPUT_TOKENS=200 SUB_QUESTION_TOKENS=2000 VLLM_TQ_DECODE_V3=1 KV_SKIP_LAYERS="" \ +./run_benchmark.sh --kv-cache-dtype turboquant_4bit_nc --tag test_tq4bit \ +--num-clients 40 --num-rounds 8 --common-prefix 2000 --prefix-tokens 32000 --port 6790 & + +wait + +# Compare +python compare_results.py results/multiturn/results_test_*.json +``` + +### Simulating Different GPU Sizes + +```bash +# 288GB GPU at 0.6 util ≈ 173GB (simulates 192GB) +# 288GB GPU at 0.5 util ≈ 144GB (simulates 160GB) +# 288GB GPU at 0.4 util ≈ 115GB (simulates 128GB) +``` + +## Interpreting Results + +### Good Result Pattern + +``` +Per-Round Cache Hit Rate: +Round | BF16 | FP8 | TQ +0 | 5.8% | 5.8% | 5.8% <- All same (cold start) +1 | 10% | 73% | 94% <- TQ >> FP8 >> BF16 +2 | 7% | 48% | 94% <- FP8 degrading, TQ stable +... +7 | 4% | 13% | 95% <- TQ maintains high hit rate +``` + +### Warning Signs + +- Cache hit rate same across all methods → not enough memory pressure +- TQ cache hit dropping → exceeded even TQ capacity +- FP8 similar to TQ → workload too light for 4-bit advantage + +## Troubleshooting + +### Server Fails to Start + +- Check logs in `logs/server_*.log` +- Verify GPU memory available: `rocm-smi --showmeminfo vram` +- Kill stale processes: `pkill -9 -f "vllm serve"` + +### Low Cache Hit Despite Multi-Turn + +- Verify history uses actual model response (not placeholder) +- Check `--enable-prompt-tokens-details` flag on server +- Ensure `--enable-prefix-caching` is set + +### OOM Errors + +- Reduce `--num-clients` or `--prefix-tokens` +- Lower `GPU_MEMORY_UTIL` +- Check for zombie GPU processes: `rocm-smi --showpids` + +## Files Reference + +| File | Purpose | +|------|---------| +| `results_*.json` | Raw benchmark results with per-round metrics | +| `all_results.jsonl` | Appended results for tracking | +| `server_*.log` | vLLM server output, includes cache stats | +| `benchmark_*.log` | Benchmark script output | diff --git a/benchmarks/multi_turn_tq/bench_multiturn_enhanced.py b/benchmarks/multi_turn_tq/bench_multiturn_enhanced.py new file mode 100755 index 000000000000..9a83e1feea06 --- /dev/null +++ b/benchmarks/multi_turn_tq/bench_multiturn_enhanced.py @@ -0,0 +1,653 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Enhanced Multi-Turn Benchmark for KV Cache Compression Comparison + +This script benchmarks multi-turn conversations with vLLM, capturing: +- Per-round TTFT (Time to First Token) +- Actual cached_tokens from server (not estimated) +- Cache hit rate per round and overall + +Inspired by SGLang's bench_multiturn.py but works with vLLM's OpenAI API. + +Usage: + # Start server with --enable-prompt-tokens-details + vllm serve MODEL --enable-prompt-tokens-details --enable-prefix-caching ... + + # Run benchmark + python bench_multiturn_enhanced.py --num-clients 16 --num-rounds 5 + + # Compare KV cache strategies + python bench_multiturn_enhanced.py --tag baseline + python bench_multiturn_enhanced.py --tag tq4bit +""" + +import argparse +import asyncio +import json +import random +import time +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from statistics import mean, median +from typing import Any + +import aiohttp +import numpy as np +from tqdm.asyncio import tqdm + + +@dataclass +class RequestResult: + """Result of a single request.""" + + success: bool + ttft: float = 0.0 # Time to first token (seconds) + latency: float = 0.0 # Total latency (seconds) + prompt_tokens: int = 0 + cached_tokens: int = 0 + completion_tokens: int = 0 + generated_text: str = "" # Actual model response for history + error: str = "" + + +@dataclass +class RoundMetrics: + """Metrics for a single round across all clients.""" + + ttft: list[float] = field(default_factory=list) + latency: list[float] = field(default_factory=list) + prompt_tokens: list[int] = field(default_factory=list) + cached_tokens: list[int] = field(default_factory=list) + completion_tokens: list[int] = field(default_factory=list) + + @property + def cache_hit_rate(self) -> float: + total_prompt = sum(self.prompt_tokens) + total_cached = sum(self.cached_tokens) + return total_cached / total_prompt if total_prompt > 0 else 0.0 + + @property + def avg_ttft(self) -> float: + return mean(self.ttft) if self.ttft else 0.0 + + @property + def avg_latency(self) -> float: + return mean(self.latency) if self.latency else 0.0 + + +def percentile(values: list[float], p: float) -> float: + """Calculate percentile of a list.""" + if not values: + return 0.0 + sorted_vals = sorted(values) + idx = int(p * len(sorted_vals)) + if idx >= len(sorted_vals): + idx = len(sorted_vals) - 1 + return sorted_vals[idx] + + +async def send_chat_request( + session: aiohttp.ClientSession, + url: str, + messages: list[dict], + model: str, + max_tokens: int, + timeout: float = 300.0, +) -> RequestResult: + """Send a chat completion request and measure timing.""" + + payload = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0.0, + "stream": True, + "stream_options": {"include_usage": True}, # Key: get cached_tokens! + } + + result = RequestResult(success=False) + start_time = time.perf_counter() + ttft_recorded = False + generated_text = "" + + try: + async with session.post( + url, + json=payload, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status != 200: + result.error = f"HTTP {response.status}: {await response.text()}" + return result + + async for line in response.content: + line = line.decode("utf-8").strip() + if not line or not line.startswith("data: "): + continue + + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + break + + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + + # Record TTFT on first content + if not ttft_recorded: + choices = data.get("choices", []) + if choices and choices[0].get("delta", {}).get("content"): + result.ttft = time.perf_counter() - start_time + ttft_recorded = True + + # Capture content + choices = data.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + if delta.get("content"): + generated_text += delta["content"] + + # Capture usage info (comes in final chunk) + usage = data.get("usage") + if usage: + result.prompt_tokens = usage.get("prompt_tokens", 0) + result.completion_tokens = usage.get("completion_tokens", 0) + + # Get cached_tokens from prompt_tokens_details + details = usage.get("prompt_tokens_details") + if details: + result.cached_tokens = details.get("cached_tokens", 0) + + result.latency = time.perf_counter() - start_time + result.success = True + result.generated_text = generated_text # Store actual response + + # If TTFT wasn't recorded (single chunk), use latency + if not ttft_recorded: + result.ttft = result.latency + + except asyncio.TimeoutError: + result.error = "Request timeout" + except Exception as e: + result.error = str(e) + + return result + + +async def run_round( + session: aiohttp.ClientSession, + url: str, + model: str, + client_histories: list[list[dict]], + max_tokens: int, + max_parallel: int, + pbar: tqdm, +) -> list[RequestResult]: + """Run one round of requests for all clients.""" + + semaphore = asyncio.Semaphore(max_parallel) + + async def send_one(messages: list[dict]) -> RequestResult: + async with semaphore: + result = await send_chat_request(session, url, messages, model, max_tokens) + pbar.update(1) + return result + + tasks = [send_one(history) for history in client_histories] + return await asyncio.gather(*tasks) + + +def generate_text_tokens(tokenizer_name: str, num_tokens: int, seed: int = 42) -> str: + """Generate random text of approximately num_tokens length.""" + # Simple word-based generation (approximately 1.3 tokens per word for English) + random.seed(seed) + words = [ + "the", + "quick", + "brown", + "fox", + "jumps", + "over", + "lazy", + "dog", + "hello", + "world", + "this", + "is", + "a", + "test", + "message", + "for", + "benchmarking", + "language", + "model", + "performance", + "with", + "cache", + "optimization", + "and", + "compression", + "techniques", + "that", + "improve", + "throughput", + "latency", + "memory", + "efficiency", + "in", + "production", + "systems", + "running", + "large", + "scale", + "inference", + "workloads", + "artificial", + "intelligence", + "machine", + "learning", + "deep", + "neural", + "networks", + "transformers", + "attention", + "mechanism", + "embeddings", + ] + + # Approximate tokens needed (assuming ~1.3 tokens per word) + num_words = int(num_tokens / 1.3) + text = " ".join(random.choices(words, k=num_words)) + return text + + +class MultiTurnBenchmark: + """Multi-turn benchmark with round barrier mode.""" + + def __init__(self, args): + self.args = args + self.url = f"http://{args.host}:{args.port}/v1/chat/completions" + self.round_metrics: dict[int, RoundMetrics] = {} + + # Initialize client histories with initial prompts + self.client_histories: list[list[dict]] = [] + for i in range(args.num_clients): + # Generate unique initial prompt for each client + seed = args.seed + i + + # Common prefix (shared across all clients for prefix caching) + common_text = generate_text_tokens( + args.model, args.common_prefix_tokens, seed=args.seed + ) + + # Per-client unique context + unique_text = generate_text_tokens( + args.model, args.prefix_tokens, seed=seed + ) + + # Initial user message + initial_content = ( + f"{common_text}\n\n" + f"Context for conversation {i}: {unique_text}\n\n" + f"Please summarize the above context and answer any follow-up questions." + ) + + self.client_histories.append([{"role": "user", "content": initial_content}]) + + async def run(self) -> dict[str, Any]: + """Run the benchmark with round barrier mode.""" + + total_requests = self.args.num_clients * self.args.num_rounds + + print("\nStarting Multi-Turn Benchmark") + print(f" Clients: {self.args.num_clients}") + print(f" Rounds: {self.args.num_rounds}") + print(f" Total requests: {total_requests}") + print(f" Common prefix tokens: {self.args.common_prefix_tokens}") + print(f" Per-client prefix tokens: {self.args.prefix_tokens}") + print(f" Max output tokens: {self.args.output_tokens}") + print(f" Max parallel: {self.args.max_parallel}") + print() + + connector = aiohttp.TCPConnector(limit=self.args.max_parallel * 2) + timeout = aiohttp.ClientTimeout(total=600) + + async with aiohttp.ClientSession( + connector=connector, timeout=timeout + ) as session: + pbar = tqdm(total=total_requests, desc="Requests") + start_time = time.perf_counter() + + for round_num in range(self.args.num_rounds): + self.round_metrics[round_num] = RoundMetrics() + + # Send all requests for this round + results = await run_round( + session=session, + url=self.url, + model=self.args.model, + client_histories=self.client_histories, + max_tokens=self.args.output_tokens, + max_parallel=self.args.max_parallel, + pbar=pbar, + ) + + # Process results and update histories + for i, result in enumerate(results): + if not result.success: + print(f"Round {round_num}, Client {i} failed: {result.error}") + continue + + # Record metrics + metrics = self.round_metrics[round_num] + metrics.ttft.append(result.ttft) + metrics.latency.append(result.latency) + metrics.prompt_tokens.append(result.prompt_tokens) + metrics.cached_tokens.append(result.cached_tokens) + metrics.completion_tokens.append(result.completion_tokens) + + # Update history with actual assistant response + # This is critical for prefix caching to work correctly! + self.client_histories[i].append( + {"role": "assistant", "content": result.generated_text} + ) + + # Add next user message (sub-question) + if round_num < self.args.num_rounds - 1: + sub_question = generate_text_tokens( + self.args.model, + self.args.sub_question_tokens, + seed=self.args.seed + round_num * 1000 + i, + ) + self.client_histories[i].append( + { + "role": "user", + "content": f"Follow-up question {round_num + 1}: {sub_question}", + } + ) + + # Print round summary + metrics = self.round_metrics[round_num] + print( + f"\n Round {round_num}: " + f"TTFT={metrics.avg_ttft:.3f}s, " + f"Cache Hit Rate={metrics.cache_hit_rate:.2%}, " + f"Cached={sum(metrics.cached_tokens)}/{sum(metrics.prompt_tokens)} tokens" + ) + + pbar.close() + total_time = time.perf_counter() - start_time + + return self._generate_report(total_time) + + def _generate_report(self, total_time: float) -> dict[str, Any]: + """Generate the final report.""" + + # Aggregate all metrics + all_ttft = [] + all_latency = [] + all_prompt_tokens = [] + all_cached_tokens = [] + all_completion_tokens = [] + + for metrics in self.round_metrics.values(): + all_ttft.extend(metrics.ttft) + all_latency.extend(metrics.latency) + all_prompt_tokens.extend(metrics.prompt_tokens) + all_cached_tokens.extend(metrics.cached_tokens) + all_completion_tokens.extend(metrics.completion_tokens) + + total_prompt = sum(all_prompt_tokens) + total_cached = sum(all_cached_tokens) + total_completion = sum(all_completion_tokens) + overall_cache_hit_rate = ( + total_cached / total_prompt if total_prompt > 0 else 0.0 + ) + + report = { + "config": { + "num_clients": self.args.num_clients, + "num_rounds": self.args.num_rounds, + "common_prefix_tokens": self.args.common_prefix_tokens, + "prefix_tokens": self.args.prefix_tokens, + "sub_question_tokens": self.args.sub_question_tokens, + "output_tokens": self.args.output_tokens, + "model": self.args.model, + "tag": self.args.tag, + }, + "summary": { + "total_requests": len(all_ttft), + "total_time_sec": total_time, + "throughput_req_per_sec": len(all_ttft) / total_time, + "input_throughput_tok_per_sec": total_prompt / total_time, + "output_throughput_tok_per_sec": total_completion / total_time, + "overall_cache_hit_rate": overall_cache_hit_rate, + "total_prompt_tokens": total_prompt, + "total_cached_tokens": total_cached, + "total_completion_tokens": total_completion, + "ttft": { + "mean": mean(all_ttft) if all_ttft else 0, + "median": median(all_ttft) if all_ttft else 0, + "p90": percentile(all_ttft, 0.9), + "p99": percentile(all_ttft, 0.99), + "max": max(all_ttft) if all_ttft else 0, + }, + "latency": { + "mean": mean(all_latency) if all_latency else 0, + "median": median(all_latency) if all_latency else 0, + "p90": percentile(all_latency, 0.9), + "p99": percentile(all_latency, 0.99), + "max": max(all_latency) if all_latency else 0, + }, + }, + "per_round": {}, + } + + # Per-round breakdown + for round_num, metrics in self.round_metrics.items(): + report["per_round"][f"round_{round_num}"] = { + "num_requests": len(metrics.ttft), + "cache_hit_rate": metrics.cache_hit_rate, + "total_prompt_tokens": sum(metrics.prompt_tokens), + "total_cached_tokens": sum(metrics.cached_tokens), + "ttft_mean": metrics.avg_ttft, + "ttft_p90": percentile(metrics.ttft, 0.9), + "latency_mean": metrics.avg_latency, + } + + return report + + +def print_report(report: dict[str, Any]): + """Print a formatted report.""" + + print("\n" + "=" * 70) + print("MULTI-TURN BENCHMARK RESULTS") + print("=" * 70) + + config = report["config"] + print("\nConfiguration:") + print(f" Model: {config['model']}") + print(f" Tag: {config['tag']}") + print(f" Clients: {config['num_clients']}, Rounds: {config['num_rounds']}") + print(f" Common Prefix: {config['common_prefix_tokens']} tokens") + print(f" Per-Client Prefix: {config['prefix_tokens']} tokens") + + summary = report["summary"] + print("\nOverall Summary:") + print(f" Total Requests: {summary['total_requests']}") + print(f" Total Time: {summary['total_time_sec']:.2f}s") + print(f" Throughput: {summary['throughput_req_per_sec']:.2f} req/s") + print(f" Input Throughput: {summary['input_throughput_tok_per_sec']:.0f} tok/s") + print(f" Output Throughput: {summary['output_throughput_tok_per_sec']:.0f} tok/s") + + print(f"\n CACHE HIT RATE: {summary['overall_cache_hit_rate']:.2%}") + print( + f" Cached: {summary['total_cached_tokens']:,} / {summary['total_prompt_tokens']:,} tokens" + ) + + ttft = summary["ttft"] + print("\n TTFT (Time to First Token):") + print( + f" Mean: {ttft['mean'] * 1000:.1f}ms, Median: {ttft['median'] * 1000:.1f}ms" + ) + print(f" P90: {ttft['p90'] * 1000:.1f}ms, P99: {ttft['p99'] * 1000:.1f}ms") + + latency = summary["latency"] + print("\n Latency (End-to-End):") + print( + f" Mean: {latency['mean'] * 1000:.1f}ms, Median: {latency['median'] * 1000:.1f}ms" + ) + print(f" P90: {latency['p90'] * 1000:.1f}ms, P99: {latency['p99'] * 1000:.1f}ms") + + print("\nPer-Round Breakdown:") + print(f" {'Round':<8} {'TTFT Mean':>12} {'Cache Hit':>12} {'Cached Tokens':>15}") + print(f" {'-' * 8} {'-' * 12} {'-' * 12} {'-' * 15}") + + for round_key in sorted(report["per_round"].keys()): + r = report["per_round"][round_key] + round_num = round_key.replace("round_", "") + print( + f" {round_num:<8} " + f"{r['ttft_mean'] * 1000:>10.1f}ms " + f"{r['cache_hit_rate']:>11.1%} " + f"{r['total_cached_tokens']:>7,}/{r['total_prompt_tokens']:,}" + ) + + print("=" * 70) + + +async def check_server(url: str) -> bool: + """Check if server is available and supports required features.""" + try: + async with ( + aiohttp.ClientSession() as session, + session.get( + url.replace("/v1/chat/completions", "/v1/models"), + timeout=aiohttp.ClientTimeout(total=10), + ) as response, + ): + return response.status == 200 + except Exception: + return False + + +async def main(): + parser = argparse.ArgumentParser( + description="Enhanced Multi-Turn Benchmark for KV Cache Compression" + ) + + # Server settings + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=6789) + parser.add_argument( + "--model", + type=str, + default="/shareddata/MiniMaxAI/MiniMax-M2.7", + help="Model name (as registered in vLLM)", + ) + + # Benchmark settings + parser.add_argument( + "--num-clients", + type=int, + default=16, + help="Number of concurrent clients/conversations", + ) + parser.add_argument( + "--num-rounds", type=int, default=5, help="Number of turns per conversation" + ) + parser.add_argument( + "--max-parallel", type=int, default=32, help="Maximum parallel requests" + ) + + # Token settings + parser.add_argument( + "--common-prefix-tokens", + type=int, + default=1000, + help="Shared prefix tokens across all clients (for prefix caching)", + ) + parser.add_argument( + "--prefix-tokens", + type=int, + default=2000, + help="Unique prefix tokens per client", + ) + parser.add_argument( + "--sub-question-tokens", + type=int, + default=200, + help="Tokens per sub-question in subsequent rounds", + ) + parser.add_argument( + "--output-tokens", type=int, default=100, help="Max output tokens per response" + ) + + # Output settings + parser.add_argument( + "--tag", type=str, default="", help="Tag for this benchmark run" + ) + parser.add_argument( + "--output-dir", + type=str, + default="results/multiturn", + help="Directory for results", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + + args = parser.parse_args() + + # Set random seeds + random.seed(args.seed) + np.random.seed(args.seed) + + # Check server + url = f"http://{args.host}:{args.port}/v1/chat/completions" + print(f"Checking server at {url}...") + if not await check_server(url): + print(f"ERROR: Server not available at {url}") + print("Make sure to start the server with:") + print( + " vllm serve MODEL --enable-prompt-tokens-details --enable-prefix-caching ..." + ) + return + + print("Server is available!") + + # Run benchmark + benchmark = MultiTurnBenchmark(args) + report = await benchmark.run() + + # Print results + print_report(report) + + # Save results + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + tag = args.tag or "default" + output_file = output_dir / f"results_{tag}_{timestamp}.json" + + with open(output_file, "w") as f: + json.dump(report, f, indent=2) + + print(f"\nResults saved to: {output_file}") + + # Also append to JSONL for easy comparison + jsonl_file = output_dir / "all_results.jsonl" + with open(jsonl_file, "a") as f: + f.write(json.dumps({"timestamp": timestamp, "tag": tag, **report}) + "\n") + + print(f"Appended to: {jsonl_file}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/multi_turn_tq/compare_results.py b/benchmarks/multi_turn_tq/compare_results.py new file mode 100755 index 000000000000..3bad52431906 --- /dev/null +++ b/benchmarks/multi_turn_tq/compare_results.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Compare multi-turn benchmark results across different KV cache strategies. + +Usage: + python compare_results.py results/multiturn/results_*.json + python compare_results.py results_baseline.json results_tq4bit.json results_fp8.json +""" + +import argparse +import json +import sys +from pathlib import Path +from typing import Any + + +def load_result(filepath: str) -> dict[str, Any]: + """Load a benchmark result JSON file.""" + with open(filepath) as f: + return json.load(f) + + +def extract_tag(filepath: str) -> str: + """Extract tag from filename like results_tagname_timestamp.json""" + name = Path(filepath).stem + parts = name.replace("results_", "").split("_") + # Remove timestamp (last 2 parts: date_time) + if len(parts) >= 3: + return "_".join(parts[:-2]) + return name + + +def is_enhanced_format(data: dict) -> bool: + """Check if this is from the enhanced benchmark.""" + return "per_round" in data + + +def safe_get(data: dict, *keys, default=None): + """Safely get nested dict values.""" + for key in keys: + if isinstance(data, dict) and key in data: + data = data[key] + else: + return default + return data + + +def format_value(val, fmt=".2f"): + """Format a value, handling None.""" + if val is None: + return "N/A" + if isinstance(val, float): + return f"{val:{fmt}}" + return str(val) + + +def print_comparison_table(results: list[tuple[str, dict]]): + """Print a comparison table of key metrics.""" + + # Check if any results are from enhanced benchmark + has_enhanced = any(is_enhanced_format(d) for _, d in results) + + if has_enhanced: + # Enhanced format metrics (with cache hit rate!) + metrics = [ + ( + "TTFT Mean (ms)", + lambda d: safe_get(d, "summary", "ttft", "mean", default=0) * 1000 + if is_enhanced_format(d) + else safe_get(d, "mean_ttft_ms"), + ), + ( + "TTFT P90 (ms)", + lambda d: safe_get(d, "summary", "ttft", "p90", default=0) * 1000 + if is_enhanced_format(d) + else safe_get(d, "p90_ttft_ms"), + ), + ( + "TTFT P99 (ms)", + lambda d: safe_get(d, "summary", "ttft", "p99", default=0) * 1000 + if is_enhanced_format(d) + else safe_get(d, "p99_ttft_ms"), + ), + ( + "Latency Mean (ms)", + lambda d: safe_get(d, "summary", "latency", "mean", default=0) * 1000 + if is_enhanced_format(d) + else safe_get(d, "mean_e2e_latency_ms"), + ), + ( + "Latency P99 (ms)", + lambda d: safe_get(d, "summary", "latency", "p99", default=0) * 1000 + if is_enhanced_format(d) + else safe_get(d, "p99_e2e_latency_ms"), + ), + ( + "Cache Hit Rate (%)", + lambda d: safe_get(d, "summary", "overall_cache_hit_rate", default=0) + * 100 + if is_enhanced_format(d) + else None, + ), + ( + "Cached Tokens", + lambda d: safe_get(d, "summary", "total_cached_tokens") + if is_enhanced_format(d) + else None, + ), + ( + "Prompt Tokens", + lambda d: safe_get(d, "summary", "total_prompt_tokens") + if is_enhanced_format(d) + else None, + ), + ( + "Input Throughput (tok/s)", + lambda d: safe_get(d, "summary", "input_throughput_tok_per_sec") + if is_enhanced_format(d) + else safe_get(d, "input_throughput"), + ), + ( + "Output Throughput (tok/s)", + lambda d: safe_get(d, "summary", "output_throughput_tok_per_sec") + if is_enhanced_format(d) + else safe_get(d, "output_throughput"), + ), + ( + "Total Requests", + lambda d: safe_get(d, "summary", "total_requests") + if is_enhanced_format(d) + else safe_get(d, "total_requests"), + ), + ( + "Total Duration (s)", + lambda d: safe_get(d, "summary", "total_time_sec") + if is_enhanced_format(d) + else safe_get(d, "total_time_sec"), + ), + ] + else: + # Original vLLM format metrics + metrics = [ + ("TTFT Mean (ms)", lambda d: safe_get(d, "mean_ttft_ms")), + ("TTFT P50 (ms)", lambda d: safe_get(d, "median_ttft_ms")), + ("TTFT P90 (ms)", lambda d: safe_get(d, "p90_ttft_ms")), + ("TTFT P99 (ms)", lambda d: safe_get(d, "p99_ttft_ms")), + ("TPOT Mean (ms)", lambda d: safe_get(d, "mean_tpot_ms")), + ("TPOT P50 (ms)", lambda d: safe_get(d, "median_tpot_ms")), + ("TPOT P99 (ms)", lambda d: safe_get(d, "p99_tpot_ms")), + ("E2E Latency Mean (ms)", lambda d: safe_get(d, "mean_e2e_latency_ms")), + ("E2E Latency P99 (ms)", lambda d: safe_get(d, "p99_e2e_latency_ms")), + ("Input Throughput (tok/s)", lambda d: safe_get(d, "input_throughput")), + ("Output Throughput (tok/s)", lambda d: safe_get(d, "output_throughput")), + ("Total Requests", lambda d: safe_get(d, "total_requests")), + ("Successful Requests", lambda d: safe_get(d, "successful_requests")), + ("Failed Requests", lambda d: safe_get(d, "failed_requests")), + ("Total Duration (s)", lambda d: safe_get(d, "total_time_sec")), + ] + + # Calculate column widths + tags = [tag for tag, _ in results] + metric_width = max(len(m[0]) for m in metrics) + col_widths = [max(12, len(tag) + 2) for tag in tags] + + # Print header + header = f"{'Metric':<{metric_width}}" + for tag, width in zip(tags, col_widths): + header += f" | {tag:>{width}}" + print("=" * len(header)) + print(header) + print("=" * len(header)) + + # Print metrics + for metric_name, extractor in metrics: + row = f"{metric_name:<{metric_width}}" + values = [] + for _, data in results: + val = extractor(data) + values.append(val) + + # Find best value (for latency metrics, lower is better) + # For throughput, higher is better + is_throughput = "throughput" in metric_name.lower() + numeric_vals = [ + v for v in values if v is not None and isinstance(v, (int, float)) + ] + + best_val = None + if numeric_vals: + best_val = max(numeric_vals) if is_throughput else min(numeric_vals) + + for i, (val, width) in enumerate(zip(values, col_widths)): + formatted = format_value(val) + # Highlight best value + if val is not None and val == best_val and len(numeric_vals) > 1: + formatted = f"*{formatted}*" + row += f" | {formatted:>{width}}" + + print(row) + + print("=" * len(header)) + print("* = best value") + + +def calculate_improvement(baseline: dict, comparison: dict) -> dict: + """Calculate percentage improvement from baseline to comparison.""" + improvements = {} + + latency_metrics = [ + "mean_ttft_ms", + "median_ttft_ms", + "p99_ttft_ms", + "mean_tpot_ms", + "p99_tpot_ms", + "mean_e2e_latency_ms", + ] + throughput_metrics = ["input_throughput", "output_throughput"] + + for metric in latency_metrics: + base_val = safe_get(baseline, metric) + comp_val = safe_get(comparison, metric) + if base_val and comp_val and base_val > 0: + # For latency, negative change is improvement + pct = ((base_val - comp_val) / base_val) * 100 + improvements[metric] = pct + + for metric in throughput_metrics: + base_val = safe_get(baseline, metric) + comp_val = safe_get(comparison, metric) + if base_val and comp_val and base_val > 0: + # For throughput, positive change is improvement + pct = ((comp_val - base_val) / base_val) * 100 + improvements[metric] = pct + + return improvements + + +def print_improvement_summary(results: list[tuple[str, dict]]): + """Print improvement summary vs first result (baseline).""" + if len(results) < 2: + return + + baseline_tag, baseline_data = results[0] + + print(f"\n\nImprovement vs Baseline ({baseline_tag})") + print("=" * 60) + + for tag, data in results[1:]: + print(f"\n{tag}:") + improvements = calculate_improvement(baseline_data, data) + + for metric, pct in improvements.items(): + direction = "better" if pct > 0 else "worse" + sign = "+" if pct > 0 else "" + print(f" {metric}: {sign}{pct:.1f}% ({direction})") + + +def print_per_round_comparison(results: list[tuple[str, dict]]): + """Print per-round TTFT and cache hit rate comparison.""" + + # Filter to enhanced format only + enhanced_results = [ + (tag, data) for tag, data in results if is_enhanced_format(data) + ] + if not enhanced_results: + return + + print("\n\nPer-Round Comparison (TTFT and Cache Hit Rate)") + print("=" * 80) + + # Get all round numbers + all_rounds = set() + for _, data in enhanced_results: + all_rounds.update(data.get("per_round", {}).keys()) + rounds = sorted(all_rounds) + + # Print TTFT comparison + print("\nTTFT Mean (ms) by Round:") + header = f"{'Round':<10}" + for tag, _ in enhanced_results: + header += f" | {tag:>15}" + print(header) + print("-" * len(header)) + + for round_key in rounds: + round_num = round_key.replace("round_", "") + row = f"{round_num:<10}" + for _, data in enhanced_results: + round_data = data.get("per_round", {}).get(round_key, {}) + ttft = round_data.get("ttft_mean", 0) * 1000 + row += f" | {ttft:>15.1f}" + print(row) + + # Print Cache Hit Rate comparison + print("\nCache Hit Rate (%) by Round:") + header = f"{'Round':<10}" + for tag, _ in enhanced_results: + header += f" | {tag:>15}" + print(header) + print("-" * len(header)) + + for round_key in rounds: + round_num = round_key.replace("round_", "") + row = f"{round_num:<10}" + for _, data in enhanced_results: + round_data = data.get("per_round", {}).get(round_key, {}) + hit_rate = round_data.get("cache_hit_rate", 0) * 100 + row += f" | {hit_rate:>14.1f}%" + print(row) + + +def main(): + parser = argparse.ArgumentParser(description="Compare multi-turn benchmark results") + parser.add_argument("files", nargs="+", help="Result JSON files to compare") + parser.add_argument( + "--baseline", help="Specify baseline file for improvement calculation" + ) + parser.add_argument( + "--per-round", action="store_true", help="Show per-round breakdown" + ) + args = parser.parse_args() + + if not args.files: + print("No result files provided") + sys.exit(1) + + # Load all results + results = [] + for filepath in args.files: + try: + data = load_result(filepath) + tag = extract_tag(filepath) + results.append((tag, data)) + print(f"Loaded: {filepath} (tag: {tag})") + except Exception as e: + print(f"Error loading {filepath}: {e}") + + if not results: + print("No valid results to compare") + sys.exit(1) + + # Sort results (baseline first if specified) + if args.baseline: + baseline_tag = extract_tag(args.baseline) + results.sort(key=lambda x: (x[0] != baseline_tag, x[0])) + + print("\n") + print_comparison_table(results) + print_improvement_summary(results) + + # Always show per-round for enhanced results + if any(is_enhanced_format(d) for _, d in results): + print_per_round_comparison(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/multi_turn_tq/run_benchmark.sh b/benchmarks/multi_turn_tq/run_benchmark.sh new file mode 100755 index 000000000000..a44804d76d17 --- /dev/null +++ b/benchmarks/multi_turn_tq/run_benchmark.sh @@ -0,0 +1,288 @@ +#!/usr/bin/env bash +# +# Run Enhanced Multi-Turn Benchmark for KV Cache Compression Comparison +# +# This script: +# 1. Starts vLLM server with proper flags (--enable-prompt-tokens-details) +# 2. Runs the enhanced benchmark that captures per-round TTFT and cache hit rate +# 3. Saves results for comparison +# +# Usage: +# ./run_benchmark.sh # Defaults +# ./run_benchmark.sh --kv-cache-dtype turboquant_4bit_nc --tag tq4bit +# ./run_benchmark.sh --kv-cache-dtype auto --tag baseline +# ./run_benchmark.sh --skip-server --tag tq4bit # Use existing server +# +# Quick comparison: +# ./run_benchmark.sh --kv-cache-dtype auto --tag baseline +# ./run_benchmark.sh --kv-cache-dtype turboquant_4bit_nc --tag tq4bit +# python compare_results.py results/multiturn/results_*.json + +set -e + +# ============================================================================ +# CONFIGURABLE PARAMETERS +# ============================================================================ + +# Model Configuration +MODEL="${MODEL:-/shareddata/MiniMaxAI/MiniMax-M2.7}" +TP_SIZE="${TP_SIZE:-2}" +MAX_MODEL_LEN="${MAX_MODEL_LEN:-8192}" +GPU_MEMORY_UTIL="${GPU_MEMORY_UTIL:-0.9}" +PORT="${PORT:-6789}" + +# KV Cache Configuration +KV_CACHE_DTYPE="${KV_CACHE_DTYPE:-auto}" + +# Attention Backend (leave empty to let vLLM auto-select) +ATTENTION_BACKEND="${ATTENTION_BACKEND:-}" + +# Benchmark Configuration +NUM_CLIENTS="${NUM_CLIENTS:-16}" +NUM_ROUNDS="${NUM_ROUNDS:-5}" +MAX_PARALLEL="${MAX_PARALLEL:-32}" + +# Token Configuration +COMMON_PREFIX_TOKENS="${COMMON_PREFIX_TOKENS:-1000}" +PREFIX_TOKENS="${PREFIX_TOKENS:-2000}" +SUB_QUESTION_TOKENS="${SUB_QUESTION_TOKENS:-200}" +OUTPUT_TOKENS="${OUTPUT_TOKENS:-100}" + +# Result tagging +TAG="${TAG:-}" + +# Control flags +SKIP_SERVER="${SKIP_SERVER:-0}" + +# Paths +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RESULTS_DIR="${SCRIPT_DIR}/results/multiturn" +LOGS_DIR="${SCRIPT_DIR}/logs" + +# ============================================================================ +# PARSE COMMAND LINE ARGUMENTS +# ============================================================================ + +while [[ $# -gt 0 ]]; do + case $1 in + --kv-cache-dtype) + KV_CACHE_DTYPE="$2" + shift 2 + ;; + --tag) + TAG="$2" + shift 2 + ;; + --num-clients) + NUM_CLIENTS="$2" + shift 2 + ;; + --num-rounds) + NUM_ROUNDS="$2" + shift 2 + ;; + --common-prefix) + COMMON_PREFIX_TOKENS="$2" + shift 2 + ;; + --prefix-tokens) + PREFIX_TOKENS="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --skip-server) + SKIP_SERVER=1 + shift + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "KV Cache Options:" + echo " --kv-cache-dtype TYPE KV cache dtype (auto, fp8_e4m3, turboquant_4bit_nc)" + echo " --tag NAME Tag for result files" + echo "" + echo "Benchmark Options:" + echo " --num-clients N Number of concurrent clients (default: $NUM_CLIENTS)" + echo " --num-rounds N Number of conversation rounds (default: $NUM_ROUNDS)" + echo " --common-prefix N Shared prefix tokens (default: $COMMON_PREFIX_TOKENS)" + echo " --prefix-tokens N Per-client prefix tokens (default: $PREFIX_TOKENS)" + echo "" + echo "Server Options:" + echo " --port N Server port (default: $PORT)" + echo " --skip-server Use existing server (don't restart)" + echo "" + echo "Example - Compare KV cache strategies:" + echo " $0 --kv-cache-dtype auto --tag baseline" + echo " $0 --kv-cache-dtype turboquant_4bit_nc --tag tq4bit" + echo " python compare_results.py results/multiturn/results_*.json" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# ============================================================================ +# SETUP +# ============================================================================ + +mkdir -p "$RESULTS_DIR" "$LOGS_DIR" + +# Generate tag from KV cache dtype if not provided +if [[ -z "$TAG" ]]; then + TAG="${KV_CACHE_DTYPE//\//_}" + TAG="${TAG//:/_}" +fi + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +echo "============================================" +echo "Enhanced Multi-Turn Benchmark" +echo "============================================" +echo "Model: $MODEL" +echo "TP Size: $TP_SIZE" +echo "KV Cache Dtype: $KV_CACHE_DTYPE" +echo "Port: $PORT" +echo "" +echo "Benchmark Settings:" +echo " Num Clients: $NUM_CLIENTS" +echo " Num Rounds: $NUM_ROUNDS" +echo " Common Prefix: $COMMON_PREFIX_TOKENS tokens" +echo " Per-Client Prefix: $PREFIX_TOKENS tokens" +echo " Sub-Question: $SUB_QUESTION_TOKENS tokens" +echo " Output Tokens: $OUTPUT_TOKENS" +echo "" +echo "Tag: $TAG" +echo "============================================" + +# ============================================================================ +# SERVER MANAGEMENT +# ============================================================================ + +cleanup() { + echo "Cleaning up..." + if [[ -n "${SERVER_PID:-}" ]]; then + kill $SERVER_PID 2>/dev/null || true + wait $SERVER_PID 2>/dev/null || true + fi +} + +wait_for_server() { + local port=$1 + local timeout=${2:-600} + local pid=$3 + echo "Waiting for server on port $port (timeout: ${timeout}s)..." + + local start_time=$(date +%s) + while true; do + # Check if server process is still alive + if [[ -n "$pid" ]] && ! kill -0 $pid 2>/dev/null; then + echo "Server process died unexpectedly!" + return 1 + fi + + if curl -s "http://localhost:${port}/v1/models" > /dev/null 2>&1; then + echo "Server is ready!" + return 0 + fi + + local elapsed=$(($(date +%s) - start_time)) + if [[ $elapsed -ge $timeout ]]; then + echo "Timeout waiting for server" + return 1 + fi + + sleep 5 + done +} + +if [[ "$SKIP_SERVER" -eq 0 ]]; then + trap cleanup EXIT + + # Kill any existing server on the port + echo "Checking for existing server on port $PORT..." + pkill -f "vllm.*--port.*$PORT" 2>/dev/null || true + sleep 2 + + # Build server command with --enable-prompt-tokens-details + SERVER_CMD="vllm serve $MODEL \ + --tensor-parallel-size $TP_SIZE \ + --trust-remote-code \ + --max-model-len $MAX_MODEL_LEN \ + --gpu-memory-utilization $GPU_MEMORY_UTIL \ + --port $PORT \ + --kv-cache-dtype $KV_CACHE_DTYPE \ + --enable-prefix-caching \ + --enable-prompt-tokens-details \ + --enforce-eager" + + # Add attention backend only if specified + if [[ -n "$ATTENTION_BACKEND" ]]; then + SERVER_CMD="$SERVER_CMD --attention-backend $ATTENTION_BACKEND" + fi + + # Add kv-cache-dtype-skip-layers if specified (even if empty string) + if [[ -n "${KV_SKIP_LAYERS+x}" ]]; then + SERVER_CMD="$SERVER_CMD --kv-cache-dtype-skip-layers \"$KV_SKIP_LAYERS\"" + fi + + echo "" + echo "Starting vLLM server..." + echo "Command: $SERVER_CMD" + echo "" + + SERVER_LOG="${LOGS_DIR}/server_${TAG}_${TIMESTAMP}.log" + $SERVER_CMD > "$SERVER_LOG" 2>&1 & + SERVER_PID=$! + + echo "Server PID: $SERVER_PID" + echo "Server log: $SERVER_LOG" + + if ! wait_for_server $PORT 600 $SERVER_PID; then + echo "Failed to start server. Check log: $SERVER_LOG" + tail -50 "$SERVER_LOG" + exit 1 + fi +else + echo "Using existing server on port $PORT" +fi + +# ============================================================================ +# RUN ENHANCED BENCHMARK +# ============================================================================ + +echo "" +echo "Running enhanced multi-turn benchmark..." +echo "" + +BENCHMARK_LOG="${LOGS_DIR}/benchmark_${TAG}_${TIMESTAMP}.log" + +python "${SCRIPT_DIR}/bench_multiturn_enhanced.py" \ + --host localhost \ + --port $PORT \ + --model "$MODEL" \ + --num-clients $NUM_CLIENTS \ + --num-rounds $NUM_ROUNDS \ + --max-parallel $MAX_PARALLEL \ + --common-prefix-tokens $COMMON_PREFIX_TOKENS \ + --prefix-tokens $PREFIX_TOKENS \ + --sub-question-tokens $SUB_QUESTION_TOKENS \ + --output-tokens $OUTPUT_TOKENS \ + --tag "$TAG" \ + --output-dir "$RESULTS_DIR" \ + 2>&1 | tee "$BENCHMARK_LOG" + +echo "" +echo "============================================" +echo "Benchmark Complete!" +echo "============================================" +echo "Results in: $RESULTS_DIR" +echo "Log: $BENCHMARK_LOG" +echo "" +echo "To compare results:" +echo " python ${SCRIPT_DIR}/compare_results.py ${RESULTS_DIR}/results_*.json" diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-TQ-t4nc.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-TQ-t4nc.yaml index 8def64d371d0..6ff836f1a0c0 100644 --- a/tests/evals/gpt_oss/configs/gpt-oss-20b-TQ-t4nc.yaml +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-TQ-t4nc.yaml @@ -4,4 +4,6 @@ model_name: "openai/gpt-oss-20b" metric_threshold: 0.56 reasoning_effort: "low" server_args: "--kv-cache-dtype turboquant_4bit_nc --kv_cache_dtype_skip_layers sliding_window" -# --max-model-len 4096 +env: + VLLM_TQ_DECODE_V3: "1" + diff --git a/tests/evals/gpt_oss/test_gpqa_correctness.py b/tests/evals/gpt_oss/test_gpqa_correctness.py index 63188ec40767..c17fb45ddd3a 100644 --- a/tests/evals/gpt_oss/test_gpqa_correctness.py +++ b/tests/evals/gpt_oss/test_gpqa_correctness.py @@ -125,7 +125,7 @@ def test_gpqa_correctness(config_filename): server_args.extend( [ "--trust-remote-code", - "--enforce-eager", + # "--enforce-eager", "--disable-uvicorn-access-log", ] ) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index b2a2295ce461..fd5d3d4e25cb 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -379,10 +379,6 @@ def __init__( # Initialize KV cache quantization attributes _init_kv_cache_quant(self, quant_config, prefix) - # Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype - if kv_cache_dtype.startswith("turboquant_"): - self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix) - # for attn backends supporting query quantization self.query_quant = None if ( @@ -403,50 +399,6 @@ def __init__( else GroupShape.PER_TENSOR, ) - def _init_turboquant_buffers( - self, cache_dtype: str, head_size: int, prefix: str - ) -> None: - """Initialize TurboQuant centroids for Lloyd-Max quantization.""" - from vllm.model_executor.layers.quantization.turboquant.centroids import ( - get_centroids, - ) - from vllm.model_executor.layers.quantization.turboquant.config import ( - TurboQuantConfig, - ) - - tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size) - - self.register_buffer( - "_tq_centroids", - get_centroids(head_size, tq_config.centroid_bits), - ) - self._tq_config = tq_config - - # Pre-allocate decode intermediate buffers so model.to(device) moves - # them to GPU *before* the memory profiler runs. Without this the - # profiler gives all free memory to KV cache blocks and the first - # decode OOMs when these buffers are lazily allocated. - _vllm_cfg = get_current_vllm_config() - B = _vllm_cfg.scheduler_config.max_num_seqs - Hq = self.num_heads - S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph - D = head_size - self.register_buffer( - "_tq_mid_o_buf", - torch.empty(B, Hq, S, D + 1, dtype=torch.float32), - persistent=False, - ) - self.register_buffer( - "_tq_output_buf", - torch.empty(B, Hq, D, dtype=torch.float32), - persistent=False, - ) - self.register_buffer( - "_tq_lse_buf", - torch.empty(B, Hq, dtype=torch.float32), - persistent=False, - ) - def forward( self, query: torch.Tensor, diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py index c145d1ea5259..0f1f7bf37988 100644 --- a/vllm/v1/attention/backends/turboquant_attn.py +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -27,6 +27,10 @@ from vllm.config import get_current_vllm_config from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.turboquant.centroids import ( + get_centroids, +) from vllm.triton_utils import triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -56,6 +60,12 @@ triton_turboquant_decode_attention_v3, ) from vllm.v1.attention.ops.triton_unified_attention import unified_attention +from vllm.v1.worker.workspace import ( + current_workspace_manager, + is_workspace_manager_initialized, +) + +logger = init_logger(__name__) # Opt-in flag to dispatch decode path to the v2 Triton kernel. # v1 remains the default. Set VLLM_TQ_DECODE_V2=1 to enable v2. @@ -69,6 +79,11 @@ if _HAS_FLASH_ATTN: from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func +logger.info_once( + "TurboQuant has flash attn: %s, decode kernel: %s", + _HAS_FLASH_ATTN, + "v3" if _USE_TQ_V3 else "v2" if _USE_TQ_V2 else "v1", +) # Continuation prefill: for small continuation chunks (q_len ≤ threshold), # use the TQ decode kernel directly instead of full-dequant + flash_attn. # do_kv_cache_update already stored all tokens to TQ cache, so the decode @@ -326,9 +341,15 @@ def _ensure_on_device(self, layer, device): H = _build_hadamard(D, str(device)) layer._tq_PiT = H layer._tq_Pi = H + # fp16 copy for rotation in continuation prefill path + layer._tq_Pi_half = H.to(torch.float16) + + # Centroids for Lloyd-Max quantization. + layer._tq_centroids = get_centroids(D, self.tq_config.centroid_bits).to( + device=device, dtype=torch.float32 + ) - c = layer._tq_centroids.to(device=device, dtype=torch.float32) - c_sorted, _ = c.sort() + c_sorted, _ = layer._tq_centroids.sort() layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2 layer._tq_cached = True @@ -574,7 +595,17 @@ def _prefill_attention( # Pre-allocate cu_seqlens for single-request flash_attn calls # to avoid per-request host→device tensor creation. - _cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32) + if not hasattr(self, "_cu_2"): + self._cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32) + # Cache arange on self (avoid per-call kernel launch). + _max_seq = attn_metadata.max_seq_len + _ac: torch.Tensor | None = getattr(self, "_arange_cache", None) + if _ac is None or _ac.shape[0] <= _max_seq: + _ac = torch.arange( + 0, _max_seq + 1, device=query.device, dtype=attn_metadata.seq_lens.dtype + ) + self._arange_cache = _ac + _arange_cache: torch.Tensor = _ac for i in range(num_reqs): q_start = qsl[i] @@ -624,8 +655,8 @@ def _prefill_attention( sinks=self.sinks, ) elif _HAS_FLASH_ATTN: - _cu_2[1] = q_len - cu = _cu_2 + self._cu_2[1] = q_len + cu = self._cu_2 out = flash_attn_varlen_func( q=q_seq, k=k_seq, @@ -659,12 +690,8 @@ def _prefill_attention( if q_len <= _CONTINUATION_DECODE_THRESHOLD: # Fast path: treat each query as a decode request # with incremental seq_lens for causal masking. - synth_seq_lens = torch.arange( - cached_len + 1, - seq_len + 1, - device=query.device, - dtype=attn_metadata.seq_lens.dtype, - ) + # Slice from pre-built arange (no kernel launch) + synth_seq_lens = _arange_cache[cached_len + 1 : seq_len + 1] synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1) if _USE_TQ_V3: out = triton_turboquant_decode_attention_v3( @@ -780,16 +807,17 @@ def _continuation_prefill( # Reuse cached buffers to avoid per-call allocation (~16MB at 8K). alloc_len = math.ceil(cached_len / block_size) * block_size buf_shape = (1, Hk, alloc_len, D) - k_buf = getattr(layer, "_tq_k_dequant_buf", None) - if k_buf is None or k_buf.shape[2] < alloc_len: - k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device) - v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device) - layer._tq_k_dequant_buf = k_buf - layer._tq_v_dequant_buf = v_buf - else: - v_buf = layer._tq_v_dequant_buf - k_cached = k_buf[:, :, :alloc_len, :].zero_() - v_cached = v_buf[:, :, :alloc_len, :].zero_() + # Use WorkspaceManager for dequant buffers. + # Shared across all layers — saves 60× memory at long context. + # Required for CUDA Graph capture (per-layer growth incompatible with CG). + k_buf, v_buf = current_workspace_manager().get_simultaneous( + (buf_shape, torch.float16), + (buf_shape, torch.float16), + ) + # Skip .zero_() — kernel writes all positions up to cached_len, + # and we only read [:cached_len] afterwards. + k_cached = k_buf[:, :, :alloc_len, :] + v_cached = v_buf[:, :, :alloc_len, :] # Opt#3 SoA layout constants (must match store-side computation). key_fp8 = self.tq_config.key_fp8 @@ -840,24 +868,30 @@ def _continuation_prefill( # Inverse-rotate MSE keys back to original space if not self.tq_config.key_fp8: - k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float() - k_flat = k_flat @ Pi - k_cached_trim = ( - k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1) - ) # (cached_len, Hk, D) + # fp16 matmul for rotation (2× less bandwidth, uses fp16 tensor cores) + Pi_half = layer._tq_Pi_half + k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D) + k_flat = k_flat @ Pi_half + k_cached_trim = k_flat.reshape(Hk, cached_len, D).transpose( + 0, 1 + ) # (cached_len, Hk, D) — already fp16 else: - k_cached_trim = ( - k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous() + k_cached_trim = k_cached[0, :, :cached_len, :].transpose( + 0, 1 ) # (cached_len, Hk, D) - v_cached_trim = ( - v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous() - ) # (cached_len, Hk, D) + # Skip .contiguous() — the copy into k_full/v_full handles layout + v_cached_trim = v_cached[0, :, :cached_len, :].transpose(0, 1) # Concatenate cached + current chunk K/V (match query dtype) + # Pre-allocate full K/V buffer, copy into slices (no cat alloc) qdtype = query.dtype - k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0) - v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0) + k_full = torch.empty(seq_len, Hk, D, dtype=qdtype, device=device) + v_full = torch.empty(seq_len, Hk, D, dtype=qdtype, device=device) + k_full[:cached_len] = k_cached_trim.to(qdtype) + k_full[cached_len:] = key_chunk + v_full[:cached_len] = v_cached_trim.to(qdtype) + v_full[cached_len:] = val_chunk # Attention: q_len queries attending to seq_len K/V with causal mask if _HAS_FLASH_ATTN: @@ -907,12 +941,23 @@ def _decode_attention( PiT: torch.Tensor | None = None, layer: torch.nn.Module | None = None, ) -> torch.Tensor: - # Grab cached decode buffers from the layer (lazily allocated). + # Acquire shared decode scratch buffers from WorkspaceManager. + # Layers execute sequentially so one set of buffers is sufficient. + # Falls back to kernel-internal allocation if workspace unavailable. + B = query.shape[0] + D = self.head_size + S = self.max_num_kv_splits + Hq = self.num_heads mid_o_buf = output_buf = lse_buf = None - if layer is not None: - mid_o_buf = getattr(layer, "_tq_mid_o_buf", None) - output_buf = getattr(layer, "_tq_output_buf", None) - lse_buf = getattr(layer, "_tq_lse_buf", None) + if is_workspace_manager_initialized(): + # output_buf in query dtype — matches the in-kernel fp16 cast in stage2. + mid_o_buf, output_buf, lse_buf = ( + current_workspace_manager().get_simultaneous( + ((B, Hq, S, D + 1), torch.float32), + ((B, Hq, D), query.dtype), + ((B, Hq), torch.float32), + ) + ) if _USE_TQ_V3: result = triton_turboquant_decode_attention_v3( diff --git a/vllm/v1/attention/ops/triton_decode_attention.py b/vllm/v1/attention/ops/triton_decode_attention.py index 8118db0da8cf..e1059b47bcba 100644 --- a/vllm/v1/attention/ops/triton_decode_attention.py +++ b/vllm/v1/attention/ops/triton_decode_attention.py @@ -551,6 +551,7 @@ def _fwd_kernel_stage2( NUM_KV_SPLITS: tl.constexpr, BLOCK_DV: tl.constexpr, Lv: tl.constexpr, + OUTPUT_FP16: tl.constexpr = 0, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -587,9 +588,12 @@ def _fwd_kernel_stage2( e_sum = e_sum * old_scale + exp_logic e_max = n_e_max + result = acc / e_sum + if OUTPUT_FP16: + result = result.to(tl.float16) tl.store( o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, - acc / e_sum, + result, mask=mask_d, ) lse_val = e_max + tl.log(e_sum) diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index f60d5ca0b742..f82e7e2869dd 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -608,10 +608,16 @@ def triton_turboquant_decode_attention( ) # Stage 2: Reduce across KV splits - if output_buf is not None and output_buf.shape[0] >= B: + # Output in query dtype — eliminates float16_copy kernel after stage2 + out_dtype = query.dtype + if ( + output_buf is not None + and output_buf.shape[0] >= B + and output_buf.dtype == out_dtype + ): output = output_buf[:B, :Hq, :D] else: - output = torch.empty(B, Hq, D, dtype=torch.float32, device=device) + output = torch.empty(B, Hq, D, dtype=out_dtype, device=device) if buf_holder is not None: buf_holder._tq_output_buf = output if lse_buf is not None and lse_buf.shape[0] >= B: @@ -636,8 +642,9 @@ def triton_turboquant_decode_attention( NUM_KV_SPLITS=NUM_KV_SPLITS, BLOCK_DV=cfg["BLOCK_D"], Lv=D, + OUTPUT_FP16=1 if out_dtype == torch.float16 else 0, num_warps=4, num_stages=2, ) - return output.to(query.dtype) + return output # already in query dtype