Merged · Changes from 1 commit
vllm/benchmarks/datasets.py (75 changes: 42 additions & 33 deletions)
@@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Do not oversample if the dataset has " \
         "fewer samples than num-prompts.",
     )
+    parser.add_argument(
+        "--skip-chat-template",
+        action="store_true",
+        help=
+        "Skip applying chat template to prompt for datasets that support it.",
+    )

     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
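
Note: the new switch replaces the custom-dataset-only --custom-skip-chat-template (removed in the next hunk) with a single flag shared by every dataset that supports it. A minimal sketch of the parsing behavior, using stock argparse in place of vLLM's FlexibleArgumentParser (which subclasses it):

    # Sketch only: stock argparse stands in for FlexibleArgumentParser.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--skip-chat-template",
        action="store_true",
        help="Skip applying chat template to prompt for datasets that support it.",
    )

    # store_true: absent -> False, present -> True; dashes map to underscores.
    assert parser.parse_args([]).skip_chat_template is False
    assert parser.parse_args(["--skip-chat-template"]).skip_chat_template is True
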
@@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help=
         "Number of output tokens per request, used only for custom dataset.",
     )
-    custom_group.add_argument(
-        "--custom-skip-chat-template",
-        action="store_true",
-        help=
-        "Skip applying chat template to prompt, used only for custom dataset.",
-    )

     spec_bench_group = parser.add_argument_group("spec bench dataset options")
     spec_bench_group.add_argument(
@@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             output_len=args.custom_output_len,
-            skip_chat_template=args.custom_skip_chat_template,
+            skip_chat_template=args.skip_chat_template,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         )
@@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
+            skip_chat_template=args.skip_chat_template,
             **hf_kwargs
         )

@@ -1815,7 +1816,6 @@ def load_data(self) -> None:

     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
-        kwargs["skip_chat_template"] = False
         return super().sample(**kwargs)
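
With the hard-coded override removed, this subclass forwards whatever the caller passed, so the new CLI flag reaches the parent's sample() through **kwargs. A toy illustration of the pass-through (class names hypothetical, not from the diff):

    class Base:
        def sample(self, skip_chat_template: bool = False, **kwargs) -> list:
            return ["raw" if skip_chat_template else "templated"]

    class Derived(Base):
        def sample(self, **kwargs) -> list:
            # no forced kwargs["skip_chat_template"] = False any more
            return super().sample(**kwargs)

    assert Derived().sample(skip_chat_template=True) == ["raw"]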


@@ -2221,6 +2221,7 @@ def sample(self,
                num_requests: int,
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
+               skip_chat_template: bool = False,
                request_id_prefix: str = "",
                no_oversample: bool = False,
                **kwargs) -> list:
@@ -2236,14 +2237,15 @@ def sample(self,
             )

             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )

             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
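
The same guard now appears in each sample() implementation: when the flag is set, the raw prompt is tokenized as-is; otherwise it is first wrapped in the model's chat template, which also changes the measured prompt_len. A self-contained sketch of the pattern (the model name is illustrative, not taken from the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    def render_prompt(prompt: str, skip_chat_template: bool = False) -> str:
        if not skip_chat_template:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
        return prompt

    raw = render_prompt("Hello!", skip_chat_template=True)  # "Hello!" verbatim
    templated = render_prompt("Hello!")  # wrapped in chat markup
    # Template tokens inflate the prompt length the benchmark records.
    print(len(tokenizer(raw).input_ids), len(tokenizer(templated).input_ids))
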
@@ -2284,6 +2286,7 @@ def sample(
         num_requests: int,
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         **kwargs,
@@ -2298,14 +2301,18 @@ def sample(
             prompt = item["turns"][0]

             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+
+            # REMOVE
+            print(f"Prompt {i}: {prompt}\n---")

             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2349,6 +2356,7 @@ def sample(
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         output_len: Optional[int] = None,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         min_distance: float = 0.0,
@@ -2372,7 +2380,7 @@ def sample(

             # template copied from
             # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
-            instruction = f"""Given a code file, please apply the change requests and generate the new file.
+            prompt = f"""Given a code file, please apply the change requests and generate the new file.

 Original file:
 ```python
@@ -2385,14 +2393,15 @@ def sample(
 Please generate the new code file in the "New file" section below.""" # noqa: E501

             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": instruction
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )

             prompt_len = len(tokenizer(prompt).input_ids)

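The Blazedit hunks above also rename instruction to prompt ahead of the guard; without the rename, prompt would be unbound on the skip path. A reduced sketch of the pattern, with a stand-in for the template call:

    def build(skip_chat_template: bool) -> str:
        # Assign the raw text to 'prompt' first, as the diff now does.
        prompt = "Given a code file, please apply the change requests..."
        if not skip_chat_template:
            # stand-in for tokenizer.apply_chat_template(...)
            prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
        return prompt

    assert build(True).startswith("Given a code file")
    assert build(False).startswith("<|user|>")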