
Commit e56f44d

Support datasets in vllm bench serve and sync with benchmark_[serving,datasets].py (#18566)
1 parent e0cbad4 commit e56f44d

File tree

3 files changed: +692 -101 lines changed

vllm/benchmarks/datasets.py

Lines changed: 175 additions & 10 deletions
@@ -62,6 +62,7 @@ class SampleRequest:
 
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
+    IS_MULTIMODAL = False
 
     def __init__(
         self,
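
Illustrative sketch, not part of the diff: the code that consumes the new IS_MULTIMODAL flag lives in the other two files changed by this commit and is not shown on this page, so the check below is only a hypothetical example of how a benchmark driver could use the flag. It assumes nothing beyond what the hunks below show, namely that multimodal datasets override the class attribute to True.

# Hypothetical driver-side check on the new class flag (illustration only).
from vllm.benchmarks.datasets import ASRDataset, RandomDataset


def assert_backend_can_serve(dataset_cls, backend_is_multimodal: bool) -> None:
    # BenchmarkDataset.IS_MULTIMODAL defaults to False; image/audio datasets
    # (ConversationDataset, VisionArenaDataset, ASRDataset) override it to True.
    if dataset_cls.IS_MULTIMODAL and not backend_is_multimodal:
        raise ValueError(f"{dataset_cls.__name__} emits multimodal requests; "
                         "use an endpoint that accepts them.")


assert_backend_can_serve(RandomDataset, backend_is_multimodal=False)  # passes
assert_backend_can_serve(ASRDataset, backend_is_multimodal=False)     # raises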
@@ -316,13 +317,15 @@ def sample(
         )
 
         vocab_size = tokenizer.vocab_size
+        num_special_tokens = tokenizer.num_special_tokens_to_add()
+        real_input_len = input_len - num_special_tokens
 
         prefix_token_ids = (np.random.randint(
             0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
 
         # New sampling logic: [X * (1 - b), X * (1 + b)]
-        input_low = int(input_len * (1 - range_ratio))
-        input_high = int(input_len * (1 + range_ratio))
+        input_low = int(real_input_len * (1 - range_ratio))
+        input_high = int(real_input_len * (1 + range_ratio))
         output_low = int(output_len * (1 - range_ratio))
         output_high = int(output_len * (1 + range_ratio))
 
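Worked example of the new bounds computation, not part of the diff; the values (input_len=1024, a tokenizer that adds 2 special tokens, range_ratio=0.1) are illustrative:

# Worked example of the sampling bounds introduced above.
input_len = 1024
num_special_tokens = 2          # e.g. BOS/EOS added by the tokenizer
range_ratio = 0.1

real_input_len = input_len - num_special_tokens        # 1022
input_low = int(real_input_len * (1 - range_ratio))    # int(919.8)  -> 919
input_high = int(real_input_len * (1 + range_ratio))   # int(1124.2) -> 1124

# Sampled prompt lengths land in [919, 1124] before the special tokens are
# re-added, so the total request length stays close to the requested 1024.
print(input_low, input_high)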
@@ -345,6 +348,17 @@ def sample(
                          vocab_size).tolist()
             token_sequence = prefix_token_ids + inner_seq
             prompt = tokenizer.decode(token_sequence)
+            # After decoding the prompt we have to encode and decode it again.
+            # This is done because in some cases N consecutive tokens
+            # give a string tokenized into != N number of tokens.
+            # For example for GPT2Tokenizer:
+            # [6880, 6881] -> ['Ġcalls', 'here'] ->
+            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+            # To avoid uncontrolled change of the prompt length,
+            # the encoded sequence is truncated before being decoded again.
+            re_encoded_sequence = tokenizer.encode(
+                prompt, add_special_tokens=False)[:input_lens[i]]
+            prompt = tokenizer.decode(re_encoded_sequence)
             total_input_len = prefix_len + int(input_lens[i])
             requests.append(
                 SampleRequest(
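
The effect described in the new comment can be replayed in isolation. The sketch below is not part of the diff; it assumes `transformers` is installed and uses the GPT-2 token ids quoted in the comment:

# Standalone illustration of the token-count drift handled above.
from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")

token_sequence = [6880, 6881]                     # ['Ġcalls', 'here']
prompt = tok.decode(token_sequence)
re_encoded = tok.encode(prompt, add_special_tokens=False)
print(len(token_sequence), len(re_encoded))       # 2 vs 3: the length drifted

# The fix in the diff: truncate the re-encoded ids to the target length and
# decode once more, so the prompt length stays under control.
prompt = tok.decode(re_encoded[:len(token_sequence)])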
@@ -637,6 +651,7 @@ class ConversationDataset(HuggingFaceDataset):
     SUPPORTED_DATASET_PATHS = {
         'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
     }
+    IS_MULTIMODAL = True
 
     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
@@ -701,6 +716,7 @@ class VisionArenaDataset(HuggingFaceDataset):
         "lmarena-ai/vision-arena-bench-v0.1":
         lambda x: x["turns"][0][0]["content"]
     }
+    IS_MULTIMODAL = True
 
     def sample(
         self,
@@ -784,6 +800,64 @@ def sample(self,
         return sampled_requests
 
 
+# -----------------------------------------------------------------------------
+# MT-Bench Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MTBenchDataset(HuggingFaceDataset):
+    """
+    MT-Bench Dataset.
+    https://huggingface.co/datasets/philschmid/mt-bench
+
+    We create a single turn dataset for MT-Bench.
+    This is similar to Spec decoding benchmark setup in vLLM
+    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
+    """  # noqa: E501
+
+    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
+    SUPPORTED_DATASET_PATHS = {
+        "philschmid/mt-bench",
+    }
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        **kwargs,
+    ) -> list:
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt = item["turns"][0]
+
+            # apply template
+            prompt = tokenizer.apply_chat_template(
+                [{
+                    "role": "user",
+                    "content": prompt
+                }],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+
+            prompt_len = len(tokenizer(prompt).input_ids)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
 # -----------------------------------------------------------------------------
 # AIMO Dataset Implementation
 # -----------------------------------------------------------------------------
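
Rough usage sketch for the new class, not part of the diff. The constructor arguments (dataset_path, dataset_split) and the tokenizer checkpoint are assumptions for illustration; the HuggingFaceDataset constructor itself is not shown in this hunk:

# Illustrative only: constructor arguments and model name are assumptions.
from transformers import AutoTokenizer

from vllm.benchmarks.datasets import MTBenchDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
dataset = MTBenchDataset(dataset_path="philschmid/mt-bench",
                         dataset_split="train")
requests = dataset.sample(tokenizer=tokenizer, num_requests=32)
print(len(requests), requests[0].expected_output_len)  # 32, 256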
@@ -858,18 +932,18 @@ def _format_zeta_prompt(
         sample: dict,
         original_start_marker: str = "<|editable_region_start|>") -> dict:
     """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
-
-    This function formats examples from the NEP dataset
-    into prompts and expected outputs. It could be
+
+    This function formats examples from the NEP dataset
+    into prompts and expected outputs. It could be
     further extended to support more NEP datasets.
-
+
     Args:
-        sample: The dataset sample containing events,
+        sample: The dataset sample containing events,
         inputs, and outputs.
-        original_start_marker: The marker indicating the
-            start of the editable region. Defaults to
+        original_start_marker: The marker indicating the
+            start of the editable region. Defaults to
             "<|editable_region_start|>".
-
+
     Returns:
         A dictionary with the formatted prompts and expected outputs.
     """
@@ -919,3 +993,94 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
                 break
         self.maybe_oversample_requests(samples, num_requests)
         return samples
+
+
+# -----------------------------------------------------------------------------
+# ASR Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ASRDataset(HuggingFaceDataset):
+    """
+    Dataset class for processing an ASR dataset for transcription.
+    Tested on the following set:
+
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
+    |                |                                        |                          | release3-speaker-adaptation |
+    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr, ...         |
+    | LibriSpeech    | Audiobook                              | Narrated                 | "LIUM/tedlium"              |
+    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
+    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
+    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+
+    """  # noqa: E501
+
+    SUPPORTED_DATASET_PATHS = {
+        "openslr/librispeech_asr",
+        "facebook/voxpopuli",
+        "LIUM/tedlium",
+        "edinburghcstr/ami",
+        "speechcolab/gigaspeech",
+        "kensho/spgispeech",
+    }
+
+    DEFAULT_OUTPUT_LEN = 128
+    IS_MULTIMODAL = True
+
+    # TODO Whisper-specific. Abstract interface when more models are supported.
+    TRANSCRIPTION_PREAMBLE = (
+        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
+    skip_long_audios: bool = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        **kwargs,
+    ) -> list:
+        try:
+            import librosa
+        except ImportError as e:
+            raise ImportError(
+                "librosa is required for ASRDataset. Please install it "
+                "using `pip install librosa`.") from e
+
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
+        prompt_len = len(tokenizer(prompt).input_ids)
+        sampled_requests = []
+        skipped = 0
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            audio = item["audio"]
+            y, sr = audio["array"], audio["sampling_rate"]
+            duration_s = librosa.get_duration(y=y, sr=sr)
+            # Whisper max supported duration
+            if self.skip_long_audios and duration_s > 30:
+                skipped += 1
+                continue
+
+            mm_content = {"audio": (y, sr)}
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                ))
+        if skipped:
+            logger.warning(
+                "%d samples discarded from dataset due to"
+                " their length being greater than"
+                " what Whisper supports.",
+                skipped,
+            )
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
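
The 30-second cutoff above comes from Whisper's fixed audio window. A self-contained sketch of the same duration filter, not part of the diff, using synthetic audio and assuming numpy and librosa are installed:

# Standalone replay of the duration filter used in ASRDataset.sample above.
import librosa
import numpy as np

sr = 16_000
y = np.zeros(45 * sr, dtype=np.float32)        # 45 seconds of silence
duration_s = librosa.get_duration(y=y, sr=sr)  # -> 45.0

skip_long_audios = True
if skip_long_audios and duration_s > 30:
    # Mirrors the skip-and-count behaviour in the diff: Whisper only handles
    # 30 s of audio per request, so longer clips are discarded.
    print(f"skipping: {duration_s:.1f}s exceeds Whisper's 30s window")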
