class _FakeTokenizer(TokenizerLike):
    """Minimal tokenizer implementing the TokenizerLike protocol
    for testing get_sampling_params.

    Only ``vocab_size`` and the value returned by
    ``num_special_tokens_to_add()`` are configurable. Every other member
    returns a fixed constant or raises ``NotImplementedError``.
    """

    def __init__(self, vocab_size: int = 1000, num_special_tokens: int = 0) -> None:
        self._vocab_size = vocab_size
        self._num_special_tokens = num_special_tokens

    # -- Members required by TokenizerLike --

    @classmethod
    def from_pretrained(cls, path_or_repo_id, *a, **kw):  # type: ignore[override]
        # Arguments are ignored; a default-configured fake is returned.
        return cls()

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    @property
    def all_special_tokens(self) -> list[str]:
        return []

    @property
    def all_special_ids(self) -> list[int]:
        return []

    @property
    def bos_token_id(self) -> int:
        return 0

    @property
    def eos_token_id(self) -> int:
        return 1

    @property
    def pad_token_id(self) -> int:
        return 2

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def max_token_id(self) -> int:
        return self._vocab_size - 1

    @property
    def max_chars_per_token(self) -> int:
        return 4

    @property
    def truncation_side(self) -> str:
        return "right"

    def num_special_tokens_to_add(self) -> int:
        return self._num_special_tokens

    # -- Tokenization entry points: deliberately unimplemented --

    def __call__(self, text, text_pair=None, **kw):  # type: ignore[override]
        raise NotImplementedError

    def get_vocab(self) -> dict[str, int]:
        return {}

    def get_added_vocab(self) -> dict[str, int]:
        return {}

    def encode(self, text, **kw) -> list[int]:  # type: ignore[override]
        raise NotImplementedError

    def apply_chat_template(self, messages, **kw):  # type: ignore[override]
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):  # type: ignore[override]
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        raise NotImplementedError

    def decode(self, ids, skip_special_tokens: bool = False) -> str:  # type: ignore[override]
        raise NotImplementedError

    def convert_ids_to_tokens(  # type: ignore[override]
        self, ids, skip_special_tokens: bool = False
    ) -> list[str]:
        raise NotImplementedError


class TestGetSamplingParams:
    """Tests for ``get_sampling_params`` in ``vllm.benchmarks.datasets.utils``."""

    # -- helpers --

    @staticmethod
    def _tok(vocab_size: int = 1000, num_special: int = 0) -> _FakeTokenizer:
        return _FakeTokenizer(vocab_size=vocab_size, num_special_tokens=num_special)

    # -- return shape / dtype --

    def test_returns_three_arrays(self):
        rng = np.random.default_rng(0)
        result = get_sampling_params(rng, 5, 0.0, 100, 50, self._tok())
        assert len(result) == 3
        for arr in result:
            assert isinstance(arr, np.ndarray)

    @pytest.mark.parametrize("n", [1, 10, 100])
    def test_output_length_matches_num_requests(self, n: int):
        rng = np.random.default_rng(42)
        input_lens, output_lens, offsets = get_sampling_params(
            rng, n, 0.0, 64, 32, self._tok()
        )
        assert input_lens.shape == (n,)
        assert output_lens.shape == (n,)
        assert offsets.shape == (n,)

    # -- fixed lengths (range_ratio = 0) --

    def test_zero_range_ratio_gives_constant_lengths(self):
        rng = np.random.default_rng(7)
        input_lens, output_lens, _ = get_sampling_params(
            rng, 20, 0.0, 128, 64, self._tok()
        )
        assert np.all(input_lens == 128)
        assert np.all(output_lens == 64)

    def test_special_tokens_subtracted_from_input_only(self):
        rng = np.random.default_rng(7)
        input_lens, output_lens, _ = get_sampling_params(
            rng, 10, 0.0, 100, 50, self._tok(num_special=4)
        )
        # real_input_len = 100 - 4 = 96, range_ratio 0 → all 96
        assert np.all(input_lens == 96)
        # special tokens are not subtracted from output length
        assert np.all(output_lens == 50)

    # -- range ratios --

    def test_input_range_bounds(self):
        rng = np.random.default_rng(0)
        ratio = 0.5
        base = 200
        input_lens, _, _ = get_sampling_params(
            rng, 500, {"input": ratio, "output": 0.0}, base, 50, self._tok()
        )
        lo = int(np.floor(base * (1 - ratio)))
        hi = int(np.ceil(base * (1 + ratio)))
        assert np.all(input_lens >= lo)
        assert np.all(input_lens <= hi)

    def test_output_range_bounds(self):
        rng = np.random.default_rng(0)
        ratio = 0.3
        base = 100
        _, output_lens, _ = get_sampling_params(
            rng, 500, {"input": 0.0, "output": ratio}, 50, base, self._tok()
        )
        lo = max(1, int(np.floor(base * (1 - ratio))))
        hi = int(np.ceil(base * (1 + ratio)))
        assert np.all(output_lens >= lo)
        assert np.all(output_lens <= hi)

    def test_output_low_clamped_to_one(self):
        """Even with a high ratio that would push output_low to 0,
        the function clamps it to 1."""
        rng = np.random.default_rng(0)
        # output_len=1, ratio=0.99 → floor(1*0.01)=0, should clamp to 1
        _, output_lens, _ = get_sampling_params(
            rng, 50, {"input": 0.0, "output": 0.99}, 100, 1, self._tok()
        )
        assert np.all(output_lens >= 1)

    # -- offsets bounded by vocab_size --

    @pytest.mark.parametrize("vocab", [100, 32000, 128256])
    def test_offsets_within_vocab(self, vocab: int):
        rng = np.random.default_rng(0)
        _, _, offsets = get_sampling_params(
            rng, 200, 0.0, 64, 32, self._tok(vocab_size=vocab)
        )
        assert np.all(offsets >= 0)
        assert np.all(offsets < vocab)

    # -- reproducibility --

    def test_same_seed_same_results(self):
        tok = self._tok()
        rr = {"input": 0.3, "output": 0.2}
        a = get_sampling_params(np.random.default_rng(42), 50, rr, 256, 64, tok)
        b = get_sampling_params(np.random.default_rng(42), 50, rr, 256, 64, tok)
        for arr_a, arr_b in zip(a, b):
            np.testing.assert_array_equal(arr_a, arr_b)

    def test_different_seed_different_results(self):
        tok = self._tok()
        rr = {"input": 0.3, "output": 0.2}
        a = get_sampling_params(np.random.default_rng(0), 50, rr, 256, 64, tok)
        b = get_sampling_params(np.random.default_rng(1), 50, rr, 256, 64, tok)
        # Extremely unlikely all three arrays match with different seeds
        assert not all(np.array_equal(arr_a, arr_b) for arr_a, arr_b in zip(a, b))

    # -- validation / error paths --

    @pytest.mark.parametrize("bad_ratio", [-0.1, 1.0, 1.5])
    def test_invalid_input_range_ratio(self, bad_ratio: float):
        rng = np.random.default_rng(0)
        with pytest.raises(ValueError, match="input_range_ratio"):
            get_sampling_params(
                rng, 10, {"input": bad_ratio, "output": 0.0}, 100, 50, self._tok()
            )

    @pytest.mark.parametrize("bad_ratio", [-0.1, 1.0, 1.5])
    def test_invalid_output_range_ratio(self, bad_ratio: float):
        rng = np.random.default_rng(0)
        with pytest.raises(ValueError, match="output_range_ratio"):
            get_sampling_params(
                rng, 10, {"input": 0.0, "output": bad_ratio}, 100, 50, self._tok()
            )

    def test_invalid_dict_missing_keys(self):
        rng = np.random.default_rng(0)
        with pytest.raises(ValueError, match="input.*output"):
            get_sampling_params(rng, 10, {"input": 0.1}, 100, 50, self._tok())

    # -- edge cases --

    def test_input_len_zero_with_special_tokens(self):
        """input_len < num_special_tokens → real_input_len = 0, which is fine
        (range [0, 0])."""
        rng = np.random.default_rng(0)
        input_lens, _, _ = get_sampling_params(
            rng, 5, 0.0, 5, 50, self._tok(num_special=10)
        )
        # real_input_len = max(0, 5 - 10) = 0
        assert np.all(input_lens == 0)

    def test_single_request(self):
        rng = np.random.default_rng(0)
        i, o, off = get_sampling_params(rng, 1, 0.0, 100, 50, self._tok())
        assert i.shape == (1,)
        assert o.shape == (1,)
        assert off.shape == (1,)

    def test_large_num_requests(self):
        rng = np.random.default_rng(0)
        i, o, off = get_sampling_params(rng, 10_000, 0.5, 512, 128, self._tok())
        assert i.shape == (10_000,)
        assert o.shape == (10_000,)
        assert off.shape == (10_000,)
@pytest.mark.benchmark
def test_create_txt_slices_jsonl(
    hf_tokenizer: PreTrainedTokenizerBase, tmp_path: Path
) -> None:
    """Test that create_txt_slices_jsonl produces valid JSONL for CustomDataset.

    The source text is written to a temporary file, converted into a
    10-record JSONL file, and that file is then both parsed directly and
    loaded through ``CustomDataset.sample``.
    """
    txt_path = tmp_path / "input.txt"
    jsonl_path = tmp_path / "input.txt.jsonl"

    # create_txt_slices_jsonl reads and writes UTF-8 explicitly, so pin the
    # encoding here as well instead of relying on the platform default.
    txt_path.write_text(text_content, encoding="utf-8")

    create_txt_slices_jsonl(
        input_path=str(txt_path),
        output_path=str(jsonl_path),
        tokenizer_name="gpt2",
        num_prompts=10,
        input_len=10,
        output_len=10,
    )

    # Verify the JSONL file is valid and has the expected structure
    records = [
        json.loads(line)
        for line in jsonl_path.read_text(encoding="utf-8").splitlines()
    ]

    assert len(records) == 10
    for record in records:
        assert "prompt" in record
        assert "output_tokens" in record
        assert isinstance(record["prompt"], str)
        # range_ratio defaults to 0.0, so every record gets the exact length
        assert record["output_tokens"] == 10

    # Verify the JSONL file can be loaded by CustomDataset
    dataset = CustomDataset(dataset_path=str(jsonl_path))
    samples = dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=10,
        output_len=10,
        skip_chat_template=True,
    )

    assert len(samples) == 10
    assert all(sample.expected_output_len == 10 for sample in samples)
RandomMultiModalDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + SpecBench, + VisionArenaDataset, + add_dataset_parser, + add_random_dataset_base_args, + add_random_multimodal_dataset_args, + gen_prompt_decode_to_target_len, + get_samples, + is_valid_sequence, + lora_path_on_disk, + lora_tokenizer_cache, + process_image, + process_video, + zeta_prompt, +) +from vllm.benchmarks.datasets.utils import RangeRatio + +__all__ = [ + "DEFAULT_NUM_PROMPTS", + "AIMODataset", + "ASRDataset", + "BenchmarkDataset", + "BlazeditDataset", + "BurstGPTDataset", + "ConversationDataset", + "CustomDataset", + "CustomMMDataset", + "HuggingFaceDataset", + "InstructCoderDataset", + "MLPerfDataset", + "MMStarDataset", + "MMVUDataset", + "MTBenchDataset", + "MultiModalConversationDataset", + "NextEditPredictionDataset", + "PrefixRepetitionRandomDataset", + "RandomDataset", + "RandomDatasetForReranking", + "RandomMultiModalDataset", + "SampleRequest", + "ShareGPTDataset", + "SonnetDataset", + "SpecBench", + "VisionArenaDataset", + "add_dataset_parser", + "add_random_dataset_base_args", + "add_random_multimodal_dataset_args", + "gen_prompt_decode_to_target_len", + "get_samples", + "is_valid_sequence", + "lora_path_on_disk", + "lora_tokenizer_cache", + "process_image", + "process_video", + "RangeRatio", + "zeta_prompt", +] diff --git a/vllm/benchmarks/datasets/create_txt_slices_dataset.py b/vllm/benchmarks/datasets/create_txt_slices_dataset.py new file mode 100644 index 000000000000..3f7c5028a205 --- /dev/null +++ b/vllm/benchmarks/datasets/create_txt_slices_dataset.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Convert a plain-text file (local path or URL) into a JSONL dataset +compatible with ``CustomDataset`` (``--dataset-name custom``), by +randomly slicing the tokenized text into prompts. 
def load_text(path: str) -> str:
    """Load text from a local file or URL.

    Paths starting with ``http://`` or ``https://`` are fetched with
    ``urllib``; anything else is opened as a local file. In both cases the
    bytes are decoded as UTF-8.
    """
    if path.startswith(("http://", "https://")):
        with urllib.request.urlopen(path) as response:
            return response.read().decode("utf-8")
    with open(path, encoding="utf-8") as f:
        return f.read()


def create_txt_slices_jsonl(
    *,
    input_path: str,
    output_path: str,
    tokenizer_name: str,
    num_prompts: int,
    input_len: int,
    output_len: int,
    range_ratio: RangeRatio = 0.0,
    seed: int = 0,
    trust_remote_code: bool = False,
) -> None:
    """Read *input_path*, slice it into prompts, and write JSONL to
    *output_path*.

    Each output line is a JSON object with a ``prompt`` string (decoded from
    a randomly positioned slice of the tokenized source text) and an
    ``output_tokens`` integer.

    Args:
        input_path: Local path or http(s) URL of the source text.
        output_path: Destination JSONL file; overwritten if it exists.
        tokenizer_name: Name or path passed to
            ``AutoTokenizer.from_pretrained``.
        num_prompts: Number of records to generate.
        input_len: Target input length, forwarded to ``get_sampling_params``.
        output_len: Target output length, forwarded to
            ``get_sampling_params``.
        range_ratio: Length range ratio, forwarded to
            ``get_sampling_params``.
        seed: Seed for both the NumPy and the ``random`` generators.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: If the source text is empty or tokenizes to zero tokens.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name, trust_remote_code=trust_remote_code
    )

    text = load_text(input_path)
    if not text:
        raise ValueError("The text file is empty and cannot be sampled from.")

    token_ids = tokenizer(text, add_special_tokens=False).input_ids
    if not token_ids:
        raise ValueError("Tokenizing the text produced zero tokens; cannot sample.")

    # Two generators with the same seed: NumPy draws the lengths, the stdlib
    # generator draws the slice start positions.
    rng_np = np.random.default_rng(seed)
    rng_py = random.Random(seed)

    input_lens, output_lens, _ = get_sampling_params(
        rng_np,
        num_prompts,
        range_ratio,
        input_len,
        output_len,
        tokenizer,
    )

    num_available_tokens = len(token_ids)

    records: list[dict[str, object]] = []
    for i in range(num_prompts):
        req_input_len = int(input_lens[i])
        req_output_len = int(output_lens[i])

        # Randomly select a start position and slice with cycling, so a
        # requested length longer than the source wraps around to its start.
        start_pos = rng_py.randint(0, num_available_tokens - 1)
        prompt_token_ids = [
            token_ids[(start_pos + j) % num_available_tokens]
            for j in range(req_input_len)
        ]
        prompt = tokenizer.decode(prompt_token_ids, skip_special_tokens=False)

        records.append({"prompt": prompt, "output_tokens": req_output_len})

    # All records are built before the file is opened, so a failure above
    # never leaves a partially written output file behind.
    with open(output_path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    logger.info(
        "Wrote %d prompts to %s",
        len(records),
        output_path,
    )


def _range_ratio_arg(value: str) -> RangeRatio:
    """Convert a ``--range-ratio`` string; used as an argparse ``type=``.

    A plain float is tried first, then a JSON object. Anything else raises
    ``argparse.ArgumentTypeError`` so argparse prints a usage error instead
    of a traceback.
    """
    try:
        return float(value)
    except ValueError:
        pass
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(
            f"expected a float or a JSON object, got {value!r}"
        ) from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f"expected a float or a JSON object, got {value!r}"
        )
    return parsed


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Convert a plain-text file into a JSONL dataset "
        "for CustomDataset (--dataset-name custom).",
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path or URL to the source text file.",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Path for the output JSONL file.",
    )
    parser.add_argument(
        "--tokenizer",
        required=True,
        help="HuggingFace tokenizer name or path.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompt samples to generate (default: 1000).",
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=1024,
        help="Target number of input tokens per prompt (default: 1024).",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=128,
        help="Target number of output tokens per prompt (default: 128).",
    )
    parser.add_argument(
        "--range-ratio",
        # argparse also runs the converter on the string default, so
        # ``args.range_ratio`` is always already parsed.
        type=_range_ratio_arg,
        default="0.0",
        help="Range ratio for input/output length sampling (default: 0.0). "
        "A single float applies to both ISL and OSL. "
        'A JSON dict like \'{"input": 0.3, "output": 0.5}\' sets them '
        "independently. Values must be in [0, 1).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducibility (default: 0).",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from HuggingFace.",
    )

    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO)

    create_txt_slices_jsonl(
        input_path=args.input,
        output_path=args.output,
        tokenizer_name=args.tokenizer,
        num_prompts=args.num_prompts,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=args.range_ratio,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
    )


if __name__ == "__main__":
    main()
get_adapter_absolute_path @@ -60,10 +64,6 @@ DEFAULT_NUM_PROMPTS = 1000 -# ----------------------------------------------------------------------------- -# Data Classes -# ----------------------------------------------------------------------------- - @dataclass class SampleRequest: @@ -71,9 +71,9 @@ class SampleRequest: Represents a single inference request for benchmarking. """ - prompt: str | list[str] + prompt: str | list[str] | list[dict] prompt_len: int - expected_output_len: int + expected_output_len: int | None multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None lora_request: LoRARequest | None = None request_id: str | None = None @@ -110,7 +110,7 @@ def __init__( # default seed. self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED self.disable_shuffle = disable_shuffle - self.data = None + self.data: Any | None = None def apply_multimodal_chat_transformation( self, @@ -249,6 +249,7 @@ def sample( num_requests: int, request_id_prefix: str = "", no_oversample: bool = False, + **kwargs, ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. 
@@ -296,8 +297,10 @@ def maybe_oversample_requests( needed = num_requests - len(requests) additional = [] for i in range(needed): - req = deepcopy(random.choice(requests)) - req.request_id = request_id_prefix + str(len(requests) + i) + req = replace( + random.choice(requests), + request_id=request_id_prefix + str(len(requests) + i), + ) additional.append(req) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -533,7 +536,7 @@ def sample( request_id_prefix: str = "", no_oversample: bool = False, prefix_len: int = DEFAULT_PREFIX_LEN, - range_ratio: float = DEFAULT_RANGE_RATIO, + range_ratio: RangeRatio = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, batchsize: int = 1, @@ -542,24 +545,33 @@ def sample( lora_assignment: str = "random", **kwargs, ) -> list[SampleRequest]: - # validate total input tokens (prefix + sampled) is at least 1. + resolved_input_rr, _ = _resolve_range_ratios(range_ratio) + num_special = int(tokenizer.num_special_tokens_to_add()) real_input_len = max(0, int(input_len) - num_special) - min_sampled_input = math.floor(real_input_len * (1.0 - float(range_ratio))) + min_sampled_input = math.floor( + real_input_len * (1.0 - float(resolved_input_rr)) + ) min_total_input = int(prefix_len) + min_sampled_input if min_total_input < 1: raise ValueError( "--random-input-len is too small: with tokenizer special " - f"tokens {num_special} and --random-range-ratio {range_ratio}, " + f"tokens {num_special} and " + f"input range ratio {resolved_input_rr}, " "the minimum possible total input tokens (prefix + sampled) is " f"{min_total_input}. Increase --random-input-len and/or " - "--random-prefix-len, or decrease --random-range-ratio so that " - "prefix_len + floor(max(0, random_input_len - num_special)) " - "* (1 - range_ratio) >= 1." 
- ) - - input_lens, output_lens, offsets = self.get_sampling_params( - num_requests, range_ratio, input_len, output_len, tokenizer + "--random-prefix-len, or decrease the input range ratio " + "so that prefix_len + floor(max(0, random_input_len - " + "num_special)) * (1 - input_range_ratio) >= 1." + ) + + input_lens, output_lens, offsets = get_sampling_params( + self._rng, + num_requests, + range_ratio, + input_len, + output_len, + tokenizer, ) vocab_size = tokenizer.vocab_size @@ -661,55 +673,6 @@ def get_prefix( ) return adjusted_tokens - def get_sampling_params( - self, - num_requests: int, - range_ratio: float, - input_len: int, - output_len: int, - tokenizer: TokenizerLike, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Get the sampling parameters for the dataset. - """ - # Enforce range_ratio < 1 - if not (0.0 <= range_ratio < 1.0): - raise ValueError("range_ratio must be in [0, 1).") - num_special_tokens = int(tokenizer.num_special_tokens_to_add()) - real_input_len = max(0, int(input_len) - num_special_tokens) - # Bounds use floor for low and ceil for high - input_low = math.floor(real_input_len * (1 - range_ratio)) - input_high = math.ceil(real_input_len * (1 + range_ratio)) - output_low = math.floor(output_len * (1 - range_ratio)) - output_high = math.ceil(output_len * (1 + range_ratio)) - # Ensure the lower bound for output length is at least 1 to - # prevent sampling 0 tokens. 
- output_low = max(output_low, 1) - output_high = max(output_high, 1) - - if input_low > input_high: - raise ValueError( - f"Invalid input sampling interval: low={input_low} > high={input_high}" - ) - if output_low > output_high: - raise ValueError( - "Invalid output sampling interval: " - f"low={output_low} > high={output_high}" - ) - - logger.info( - "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, - input_high, - output_low, - output_high, - ) - - input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests) - output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests) - offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests) - return input_lens, output_lens, offsets - def generate_token_sequence( self, *, @@ -776,8 +739,11 @@ def sample( tokenizer: TokenizerLike, num_requests: int, request_id_prefix: str = "", - range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + no_oversample: bool = False, + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: RangeRatio = RandomDataset.DEFAULT_RANGE_RATIO, input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, batchsize: int = 1, is_reranker: bool = True, **kwargs, @@ -786,8 +752,13 @@ def sample( query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len - query_lens, _, query_offsets = self.get_sampling_params( - 1, range_ratio, query_len_param, 0, tokenizer + query_lens, _, query_offsets = get_sampling_params( + self._rng, + 1, + range_ratio, + query_len_param, + 0, + tokenizer, ) query_len = int(query_lens[0]) @@ -800,8 +771,13 @@ def sample( else: doc_len_param = input_len - query_len - n_sep_tokens - doc_lens, _, doc_offsets = self.get_sampling_params( - num_requests, range_ratio, doc_len_param, 0, tokenizer + doc_lens, _, doc_offsets = get_sampling_params( + self._rng, + num_requests, + range_ratio, + doc_len_param, + 0, + tokenizer, ) vocab_size = 
tokenizer.vocab_size @@ -1175,9 +1151,10 @@ def sample( request_id_prefix: str = "", no_oversample: bool = False, prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, - range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + range_ratio: RangeRatio = RandomDataset.DEFAULT_RANGE_RATIO, input_len: int = RandomDataset.DEFAULT_INPUT_LEN, output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + batchsize: int = 1, limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, @@ -1187,9 +1164,18 @@ def sample( enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, **kwargs, ) -> list[SampleRequest]: - # Get the sampling parameters for the dataset - input_lens, output_lens, offsets = self.get_sampling_params( - num_requests, range_ratio, input_len, output_len, tokenizer + if batchsize != 1: + raise NotImplementedError( + "batchsize > 1 is not supported for RandomMultiModalDataset." + ) + + input_lens, output_lens, offsets = get_sampling_params( + self._rng, + num_requests, + range_ratio, + input_len, + output_len, + tokenizer, ) ( @@ -1326,16 +1312,16 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, lora_path: str | None = None, max_loras: int | None = None, output_len: int | None = None, enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - no_oversample: bool = False, lora_assignment: str = "random", **kwargs, - ) -> list: - samples: list = [] + ) -> list[SampleRequest]: + samples: list[SampleRequest] = [] ind = 0 for entry in self.data: if len(samples) >= num_requests: @@ -1449,8 +1435,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser): type=str, default=None, action=_ValidateDatasetArgs, - help="Path to the sharegpt/sonnet dataset. 
def _parse_range_ratio(value: str) -> RangeRatio:
    """Parse a ``--random-range-ratio`` CLI string.

    Accepts either a plain float (``"0.3"``) or a JSON dict
    (``'{"input": 0.3, "output": 0.5}'``).

    Only the outer shape is checked here. The keys and value ranges of a
    dict are not validated by this function.

    Raises:
        ValueError: If *value* is neither a float nor a JSON object.
    """
    try:
        return float(value)
    except ValueError:
        pass  # not a plain float; fall through to the JSON form

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        # JSONDecodeError is itself a ValueError, but its message never
        # mentions the flag, so re-raise with context.
        raise ValueError(
            "--random-range-ratio must be a float or a JSON dict like "
            f'\'{{"input": 0.3, "output": 0.5}}\', got {value!r}'
        ) from err

    # Valid JSON that is not an object (null, a list, a string, a bool) is
    # not a usable range ratio; reject it here rather than passing it on.
    if not isinstance(parsed, dict):
        raise ValueError(
            "--random-range-ratio must be a float or a JSON dict like "
            f'\'{{"input": 0.3, "output": 0.5}}\', got {value!r}'
        )
    return parsed
def sample(self, **kwargs) -> list[SampleRequest]:
    """Sample requests by delegating to the parent class.

    All keyword arguments are forwarded unchanged to ``super().sample``.
    """
    # ``self`` must be declared: without it a bound call raises
    # "takes 0 positional arguments but 1 was given", and zero-argument
    # ``super()`` has no instance to bind to.
    # leverage CustomDataset sample
    return super().sample(**kwargs)
tokenized_lines = [tokenizer(line).input_ids for line in self.data] avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) @@ -2411,7 +2416,7 @@ def sample( num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) prefix_lines = self.data[:num_prefix_lines] - samples = [] + samples: list[SampleRequest] = [] ind = 0 while len(samples) < num_requests: extra_lines = random.choices( @@ -2482,11 +2487,11 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - max_loras: int | None = None, - lora_path: str | None = None, request_id_prefix: str = "", no_oversample: bool = False, lora_assignment: str = "random", + max_loras: int | None = None, + lora_path: str | None = None, **kwargs, ) -> list[SampleRequest]: samples = [] @@ -2574,15 +2579,15 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, - enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, + enable_multimodal_chat: bool = False, **kwargs, - ) -> list: + ) -> list[SampleRequest]: # Filter examples with at least 2 conversations filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) - sampled_requests = [] + sampled_requests: list[SampleRequest] = [] ind = 0 dynamic_output = output_len is None @@ -2634,15 +2639,15 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, - enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, + enable_multimodal_chat: bool = False, **kwargs, - ) -> list: + ) -> list[SampleRequest]: # Filter examples with at least 2 conversations filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) - sampled_requests = [] + sampled_requests: list[SampleRequest] = [] ind = 0 dynamic_output = output_len is None @@ -2703,12 +2708,12 @@ def sample( self, tokenizer: TokenizerLike, 
num_requests: int, - output_len: int | None = None, - enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, + enable_multimodal_chat: bool = False, **kwargs, - ) -> list: + ) -> list[SampleRequest]: parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) if parser_fn is None: raise ValueError(f"Unsupported dataset path: {self.hf_name}") @@ -2753,9 +2758,11 @@ class MMVUDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "yale-nlp/MMVU": lambda x: x["question"] - + " " - + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())), + "yale-nlp/MMVU": lambda x: ( + x["question"] + + " " + + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())) + ), } def __init__(self, **kwargs) -> None: @@ -2770,12 +2777,12 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, - enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, + enable_multimodal_chat: bool = False, **kwargs, - ) -> list: + ) -> list[SampleRequest]: parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) if parser_fn is None: raise ValueError(f"Unsupported dataset path: {self.hf_name}") @@ -2838,15 +2845,15 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, output_len: int | None = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, - request_id_prefix: str = "", - no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] + sampled_requests: list[SampleRequest] = [] for i, prompt in enumerate(self.sample_prompts(n=num_requests)): # apply template if not skip_chat_template: @@ -2903,15 +2910,15 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, + 
request_id_prefix: str = "", + no_oversample: bool = False, output_len: int | None = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, - request_id_prefix: str = "", - no_oversample: bool = False, **kwargs, - ) -> list: + ) -> list[SampleRequest]: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] + sampled_requests: list[SampleRequest] = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -2976,7 +2983,7 @@ def sample( min_distance: float = 0.0, max_distance: float = 1.0, **kwargs, - ) -> list: + ) -> list[SampleRequest]: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] @@ -3050,12 +3057,12 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, **kwargs, - ) -> list: - sampled_requests = [] + ) -> list[SampleRequest]: + sampled_requests: list[SampleRequest] = [] ind = 0 dynamic_output = output_len is None @@ -3228,18 +3235,18 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, **kwargs, - ) -> list: + ) -> list[SampleRequest]: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN if "openai" in getattr(tokenizer, "name_or_path", ""): prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" else: prompt = "" prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests = [] + sampled_requests: list[SampleRequest] = [] ind = 0 skipped = 0 asr_min_audio_len_sec = kwargs.get("asr_min_audio_len_sec") @@ -3326,9 +3333,9 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = 
None, **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. @@ -3405,12 +3412,12 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, prefix_len: int = DEFAULT_PREFIX_LEN, suffix_len: int = DEFAULT_SUFFIX_LEN, num_prefixes: int = DEFAULT_NUM_PREFIXES, output_len: int = DEFAULT_OUTPUT_LEN, - request_id_prefix: str = "", - no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: vocab_size = tokenizer.vocab_size @@ -3421,7 +3428,7 @@ def sample( f"to num_prefixes ({num_prefixes})" ) - def _generate_exact_length_tokens(target_length: int) -> list[int]: + def _generate_exact_length_tokens(target_length: int) -> tuple[list[int], int]: """Generate tokens that decode and re-encode to exactly target_length.""" # Generate random tokens @@ -3491,10 +3498,10 @@ def sample( self, tokenizer: TokenizerLike, num_requests: int, - output_len: int | None = None, - enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, + output_len: int | None = None, + enable_multimodal_chat: bool = False, **kwargs, ) -> list[SampleRequest]: # If --hf-output-len is not set, use the default output length. @@ -3516,6 +3523,7 @@ def sample( # if enable_multimodal_chat is False). prompt_len = len(tokenizer(question_text).input_ids) + prompt: str | list[dict] if enable_multimodal_chat: # If multimodal content should be embedded in the chat message, # convert to [{"role":"user","content":[...]}] diff --git a/vllm/benchmarks/datasets/utils.py b/vllm/benchmarks/datasets/utils.py new file mode 100644 index 000000000000..bc5a4340dd62 --- /dev/null +++ b/vllm/benchmarks/datasets/utils.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared utilities for benchmark dataset sampling. 
+""" + +import logging +import math + +import numpy as np + +from vllm.tokenizers import TokenizerLike + +logger = logging.getLogger(__name__) + +# Type alias: a single float applies to both ISL and OSL; a dict allows +# specifying them independently via ``{"input": …, "output": …}``. +RangeRatio = float | dict[str, float] + + +def _resolve_range_ratios( + range_ratio: RangeRatio, +) -> tuple[float, float]: + """Return ``(input_range_ratio, output_range_ratio)`` from *range_ratio*. + + *range_ratio* is either a single float (used for both input and output) + or a dict with ``"input"`` and ``"output"`` keys. + """ + if isinstance(range_ratio, dict): + try: + return float(range_ratio["input"]), float(range_ratio["output"]) + except KeyError as exc: + raise ValueError( + "When range_ratio is a dict it must contain 'input' and " + f"'output' keys, got: {sorted(range_ratio)}" + ) from exc + ratio = float(range_ratio) + return ratio, ratio + + +def get_sampling_params( + rng: np.random.Generator, + num_requests: int, + range_ratio: RangeRatio, + input_len: int, + output_len: int, + tokenizer: TokenizerLike, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Sample per-request input/output token lengths and vocab offsets. + + Lengths are drawn uniformly from integer ranges around the configured + means, controlled by *range_ratio*. It may be a single ``float`` + (applied to both input and output) or a ``dict`` with ``"input"`` and + ``"output"`` keys for independent control. + + Tokenizer special tokens are subtracted from ``input_len`` before + computing the sampling interval. + + Returns: + (input_lens, output_lens, offsets) – three 1-D ``np.ndarray`` of + shape ``(num_requests,)``. 
+ """ + input_range_ratio, output_range_ratio = _resolve_range_ratios(range_ratio) + + if not (0.0 <= input_range_ratio < 1.0): + raise ValueError("input_range_ratio must be in [0, 1).") + if not (0.0 <= output_range_ratio < 1.0): + raise ValueError("output_range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + input_low = math.floor(real_input_len * (1 - input_range_ratio)) + input_high = math.ceil(real_input_len * (1 + input_range_ratio)) + output_low = math.floor(output_len * (1 - output_range_ratio)) + output_high = math.ceil(output_len * (1 + output_range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. + output_low = max(output_low, 1) + output_high = max(output_high, 1) + + if input_low > input_high: + raise ValueError( + f"Invalid input sampling interval: low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + f"Invalid output sampling interval: low={output_low} > high={output_high}" + ) + + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, + input_high, + output_low, + output_high, + ) + + input_lens = rng.integers(input_low, input_high + 1, size=num_requests) + output_lens = rng.integers(output_low, output_high + 1, size=num_requests) + offsets = rng.integers(0, tokenizer.vocab_size, size=num_requests) + return input_lens, output_lens, offsets