@@ -62,6 +62,7 @@ class SampleRequest:
 
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
+    IS_MULTIMODAL = False
 
     def __init__(
         self,
@@ -316,13 +317,15 @@ def sample(
         )
 
         vocab_size = tokenizer.vocab_size
+        num_special_tokens = tokenizer.num_special_tokens_to_add()
+        real_input_len = input_len - num_special_tokens
 
         prefix_token_ids = (np.random.randint(
             0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
 
         # New sampling logic: [X * (1 - b), X * (1 + b)]
-        input_low = int(input_len * (1 - range_ratio))
-        input_high = int(input_len * (1 + range_ratio))
+        input_low = int(real_input_len * (1 - range_ratio))
+        input_high = int(real_input_len * (1 + range_ratio))
         output_low = int(output_len * (1 - range_ratio))
         output_high = int(output_len * (1 + range_ratio))
 
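A minimal standalone sketch (not part of the diff) of the special-token accounting introduced above; the checkpoints "gpt2" and "bert-base-uncased" are illustrative only.

```python
# Sketch: num_special_tokens_to_add() reports how many tokens the tokenizer
# appends on encode (e.g. [CLS]/[SEP] for BERT); subtracting it from input_len
# keeps the sampled prompt at the requested total length.
from transformers import AutoTokenizer

for name in ("gpt2", "bert-base-uncased"):
    tok = AutoTokenizer.from_pretrained(name)
    print(name, tok.num_special_tokens_to_add())  # gpt2 -> 0, bert -> 2
```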
@@ -345,6 +348,17 @@ def sample(
                          vocab_size).tolist()
             token_sequence = prefix_token_ids + inner_seq
             prompt = tokenizer.decode(token_sequence)
+            # After decoding the prompt we have to encode and decode it again.
+            # This is done because in some cases N consecutive tokens
+            # give a string tokenized into != N number of tokens.
+            # For example for GPT2Tokenizer:
+            # [6880, 6881] -> ['Ġcalls', 'here'] ->
+            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+            # To avoid an uncontrolled change of the prompt length,
+            # the encoded sequence is truncated before being decoded again.
+            re_encoded_sequence = tokenizer.encode(
+                prompt, add_special_tokens=False)[:input_lens[i]]
+            prompt = tokenizer.decode(re_encoded_sequence)
             total_input_len = prefix_len + int(input_lens[i])
             requests.append(
                 SampleRequest(
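A short sketch (not part of the diff) of why the re-encode step above is needed, assuming the GPT-2 tokenizer; the token ids are taken from the comment in the hunk.

```python
# Sketch: decode -> encode does not always round-trip to the same token count,
# so the re-encoded sequence is truncated to the target length before decoding.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = [6880, 6881]                        # ['Ġcalls', 'here']
text = tok.decode(ids)
re_ids = tok.encode(text, add_special_tokens=False)
print(len(ids), len(re_ids))              # 2 vs. 3 -> prompt length drifted
print(tok.decode(re_ids[:len(ids)]))      # truncate to keep the length fixed
```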
@@ -637,6 +651,7 @@ class ConversationDataset(HuggingFaceDataset):
     SUPPORTED_DATASET_PATHS = {
         'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
     }
+    IS_MULTIMODAL = True
 
     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
@@ -701,6 +716,7 @@ class VisionArenaDataset(HuggingFaceDataset):
         "lmarena-ai/vision-arena-bench-v0.1":
         lambda x: x["turns"][0][0]["content"]
     }
+    IS_MULTIMODAL = True
 
     def sample(
         self,
@@ -784,6 +800,64 @@ def sample(self,
         return sampled_requests
 
 
+# -----------------------------------------------------------------------------
+# MT-Bench Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MTBenchDataset(HuggingFaceDataset):
+    """
+    MT-Bench Dataset.
+    https://huggingface.co/datasets/philschmid/mt-bench
+
+    We create a single-turn dataset for MT-Bench.
+    This is similar to the spec decoding benchmark setup in vLLM
+    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
+    """  # noqa: E501
+
+    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
+    SUPPORTED_DATASET_PATHS = {
+        "philschmid/mt-bench",
+    }
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        **kwargs,
+    ) -> list:
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt = item["turns"][0]
+
+            # apply template
+            prompt = tokenizer.apply_chat_template(
+                [{
+                    "role": "user",
+                    "content": prompt
+                }],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+
+            prompt_len = len(tokenizer(prompt).input_ids)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
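A small usage sketch (not part of the diff) of the chat-template step used in MTBenchDataset.sample(); the checkpoint name is illustrative, and any tokenizer with a chat template would do.

```python
# Sketch of the templating done per MT-Bench turn above: wrap the first user
# turn in the model's chat template without tokenizing, then measure length.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Compose an engaging travel blog post."}],
    add_generation_prompt=True,
    tokenize=False,
)
prompt_len = len(tok(prompt).input_ids)
```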
 # -----------------------------------------------------------------------------
 # AIMO Dataset Implementation
 # -----------------------------------------------------------------------------
@@ -858,18 +932,18 @@ def _format_zeta_prompt(
         sample: dict,
         original_start_marker: str = "<|editable_region_start|>") -> dict:
     """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
-
-    This function formats examples from the NEP dataset 
-    into prompts and expected outputs. It could be 
+
+    This function formats examples from the NEP dataset
+    into prompts and expected outputs. It could be
     further extended to support more NEP datasets.
-
+
     Args:
-        sample: The dataset sample containing events, 
+        sample: The dataset sample containing events,
             inputs, and outputs.
-        original_start_marker: The marker indicating the 
-            start of the editable region. Defaults to 
+        original_start_marker: The marker indicating the
+            start of the editable region. Defaults to
             "<|editable_region_start|>".
-
+
     Returns:
         A dictionary with the formatted prompts and expected outputs.
     """
@@ -919,3 +993,94 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
                 break
         self.maybe_oversample_requests(samples, num_requests)
         return samples
+
+
+# -----------------------------------------------------------------------------
+# ASR Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ASRDataset(HuggingFaceDataset):
+    """
+    Dataset class for processing an ASR dataset for transcription.
+    Tested on the following set:
+
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
+    |                |                                        |                          | release3-speaker-adaptation |
+    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr, ...         |
+    | LibriSpeech    | Audiobook                              | Narrated                 | "clean", "other"            |
+    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
+    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
+    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+
+    """  # noqa: E501
+
+    SUPPORTED_DATASET_PATHS = {
+        "openslr/librispeech_asr",
+        "facebook/voxpopuli",
+        "LIUM/tedlium",
+        "edinburghcstr/ami",
+        "speechcolab/gigaspeech",
+        "kensho/spgispeech",
+    }
+
+    DEFAULT_OUTPUT_LEN = 128
+    IS_MULTIMODAL = True
+
+    # TODO Whisper-specific. Abstract interface when more models are supported.
+    TRANSCRIPTION_PREAMBLE = (
+        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
+    skip_long_audios: bool = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        **kwargs,
+    ) -> list:
+        try:
+            import librosa
+        except ImportError as e:
+            raise ImportError(
+                "librosa is required for ASRDataset. Please install it "
+                "using `pip install librosa`.") from e
+
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
+        prompt_len = len(tokenizer(prompt).input_ids)
+        sampled_requests = []
+        skipped = 0
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            audio = item["audio"]
+            y, sr = audio["array"], audio["sampling_rate"]
+            duration_s = librosa.get_duration(y=y, sr=sr)
+            # Whisper max supported duration
+            if self.skip_long_audios and duration_s > 30:
+                skipped += 1
+                continue
+
+            mm_content = {"audio": (y, sr)}
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                ))
+        if skipped:
+            logger.warning(
+                "%d samples discarded from dataset due to"
+                " their length being greater than"
+                " what Whisper supports.",
+                skipped,
+            )
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
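A hedged sketch (not part of the diff) of the per-item audio handling in ASRDataset.sample(), assuming a HuggingFace audio dataset row with an `audio` column; the config and split names are illustrative only and may require dataset-specific access or trust settings.

```python
# Sketch: read one audio row, skip clips longer than Whisper's 30 s window,
# and package the waveform as the multimodal payload used above.
import librosa
from datasets import load_dataset

ds = load_dataset("openslr/librispeech_asr", "clean",
                  split="validation", streaming=True)
item = next(iter(ds))
y, sr = item["audio"]["array"], item["audio"]["sampling_rate"]
if librosa.get_duration(y=y, sr=sr) <= 30:
    mm_content = {"audio": (y, sr)}
```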