185 changes: 175 additions & 10 deletions vllm/benchmarks/datasets.py
@@ -62,6 +62,7 @@ class SampleRequest:

class BenchmarkDataset(ABC):
DEFAULT_SEED = 0
IS_MULTIMODAL = False

def __init__(
self,
@@ -314,13 +315,15 @@ def sample(
)

vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens

prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])

# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
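
For intuition, a quick worked sketch of the adjusted range computation (the numbers are illustrative, not from the diff):

# Illustrative sketch only: hypothetical input_len/range_ratio values.
input_len, range_ratio, num_special_tokens = 1024, 0.1, 2
real_input_len = input_len - num_special_tokens        # 1022
input_low = int(real_input_len * (1 - range_ratio))    # 919
input_high = int(real_input_len * (1 + range_ratio))   # 1124
# Per-request lengths are then drawn uniformly from [input_low,
# input_high], so prompts land near input_len even after the
# tokenizer adds its special tokens.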

@@ -343,6 +346,17 @@ def sample(
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string that re-tokenizes into a different number of tokens.
# For example, for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid an uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decoded again.
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:input_lens[i]]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
SampleRequest(
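
A minimal standalone sketch (not part of the diff, assuming the GPT-2 tokenizer from transformers) of the round-trip drift the comment above describes:

# Sketch only: decode/encode round trips can change token counts.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = [6880, 6881]                        # ['Ġcalls', 'here']
text = tok.decode(ids)
re_ids = tok.encode(text, add_special_tokens=False)
print(len(ids), len(re_ids))              # 2 vs. 3: [1650, 939, 486]
# Truncating to the sampled length keeps the prompt length controlled:
text = tok.decode(re_ids[:len(ids)])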
@@ -635,6 +649,7 @@ class ConversationDataset(HuggingFaceDataset):
SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
}
IS_MULTIMODAL = True

def sample(self,
tokenizer: PreTrainedTokenizerBase,
@@ -699,6 +714,7 @@ class VisionArenaDataset(HuggingFaceDataset):
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
}
IS_MULTIMODAL = True

def sample(
self,
@@ -782,6 +798,64 @@ def sample(self,
return sampled_requests


# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------


class MTBenchDataset(HuggingFaceDataset):
"""
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench

We create a single-turn dataset for MT-Bench.
This is similar to the speculative decoding benchmark setup in vLLM:
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501

DEFAULT_OUTPUT_LEN = 256 # avg len used in the spec decoding bench in vLLM
SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench",
}

def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []

for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]

# apply template
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)

prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
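
A minimal sketch of the templating step above (the model name and question text are assumptions; any chat-tuned tokenizer works, and the rendered template is model-specific):

# Sketch only: how a first MT-Bench turn becomes a prompt.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Explain KV caching in one paragraph."}],
    add_generation_prompt=True,
    tokenize=False,
)
prompt_len = len(tok(prompt).input_ids)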


# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------
@@ -856,18 +930,18 @@ def _format_zeta_prompt(
sample: dict,
original_start_marker: str = "<|editable_region_start|>") -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
This function formats examples from the NEP dataset
into prompts and expected outputs. It could be

This function formats examples from the NEP dataset
into prompts and expected outputs. It could be
further extended to support more NEP datasets.

Args:
sample: The dataset sample containing events,
sample: The dataset sample containing events,
inputs, and outputs.
original_start_marker: The marker indicating the
start of the editable region. Defaults to
original_start_marker: The marker indicating the
start of the editable region. Defaults to
"<|editable_region_start|>".

Returns:
A dictionary with the formatted prompts and expected outputs.
"""
@@ -917,3 +991,94 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
break
self.maybe_oversample_requests(samples, num_requests)
return samples


# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
"""
Dataset class for processing an ASR dataset for transcription.
Tested on the following set:

+----------------+----------------------------------------+--------------------------+-----------------------------+
| Dataset | Domain | Speaking Style | hf-subset |
+----------------+----------------------------------------+--------------------------+-----------------------------+
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
| | | | release3-speaker-adaptation |
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
| LibriSpeech | Audiobook | Narrated | clean, other, all |
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
| AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+

""" # noqa: E501

SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr",
"facebook/voxpopuli",
"LIUM/tedlium",
"edinburghcstr/ami",
"speechcolab/gigaspeech",
"kensho/spgispeech",
}

DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True

# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
skip_long_audios: bool = True

def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
**kwargs,
) -> list:
try:
import librosa
except ImportError as e:
raise ImportError(
"librosa is required for ASRDataset. Please install it "
"using `pip install librosa`.") from e

output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
skipped = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
break
audio = item["audio"]
y, sr = audio["array"], audio["sampling_rate"]
duration_s = librosa.get_duration(y=y, sr=sr)
# Whisper max supported duration
if self.skip_long_audios and duration_s > 30:
skipped += 1
continue

mm_content = {"audio": (y, sr)}
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
if skipped:
logger.warning(
"%d samples discarded from dataset due to"
" their length being greater than"
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
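
A short standalone sketch of the 30-second gate above (synthetic audio; the shapes mirror HF datasets' audio feature):

# Sketch only: a synthetic 31 s clip that the loop above would skip.
import numpy as np
import librosa

sr = 16_000
y = np.zeros(31 * sr, dtype=np.float32)        # 31 s of silence
duration_s = librosa.get_duration(y=y, sr=sr)  # 31.0
print(duration_s > 30)                         # True -> would be skipped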