Merged

23 commits:
- 9369a37 Add draft for livecodebench code generation (plaguss, Feb 10, 2025)
- 2001e7b Add extra argument version_tag (plaguss, Feb 10, 2025)
- fece552 Fix import name (plaguss, Feb 10, 2025)
- e46fc2a Remove unused typed dict (plaguss, Feb 10, 2025)
- 6a3c007 Checkpoint, not ready yet, try simplifying code running and reuse pas… (plaguss, Feb 10, 2025)
- 987eb2a Add notes for expected values (plaguss, Feb 11, 2025)
- 42fb0f5 Pass version tag to downloader (plaguss, Feb 11, 2025)
- b700dc4 Modify helper module and remove dataset version tag (plaguss, Feb 14, 2025)
- 29b2bbe Remove version_tag (plaguss, Feb 14, 2025)
- a60e662 Initial version for lcb:codegeneration (plaguss, Feb 14, 2025)
- 05a7f01 Remove outdated argument docs (plaguss, Feb 14, 2025)
- deea663 Remove hardcoded system prompt and pass it via arg (plaguss, Feb 14, 2025)
- a2863f9 Merge branch 'main' into lcb-codegeneration (plaguss, Feb 14, 2025)
- 44f45b5 Add kwargs to allow passing other arguments (plaguss, Feb 14, 2025)
- 127b4cd Make generic function to parse the metric name and obtain the number … (plaguss, Feb 14, 2025)
- a372e05 Change metric name to make it more informative (plaguss, Feb 14, 2025)
- 53ab417 Add experimental way of passing the number of samples for a metric fr… (plaguss, Feb 14, 2025)
- f6a7c4f Add more processes to run the tests (plaguss, Feb 16, 2025)
- d6abcd0 Allow reading the generation parameters from the CLI (plaguss, Feb 17, 2025)
- 158d660 Update parsing arguments from CLI (plaguss, Feb 17, 2025)
- 54fa032 Remove dead code and fix test value (plaguss, Feb 17, 2025)
- 4a0fe89 Fix num_samples update (plaguss, Feb 17, 2025)
- f945fdf Add docs for the new metric_options (plaguss, Feb 17, 2025)
29 changes: 29 additions & 0 deletions docs/source/use-vllm-as-backend.mdx
@@ -58,3 +58,32 @@ model: # Model specific parameters
> [!WARNING]
> In the case of OOM issues, you might need to reduce the context size of the
> model as well as reduce the `gpu_memory_utilization` parameter.


## Dynamically changing the metric configuration

For special kinds of metrics, such as `Pass@K` or LiveCodeBench's `codegen` metric, you may need to pass metric-specific values such as the number of generations. This can be done in the YAML file as follows:

```yaml
model: # Model specific parameters
base_params:
model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16" # Model args that you would pass in the command line
generation: # Generation specific parameters
temperature: 0.3
repetition_penalty: 1.0
frequency_penalty: 0.0
presence_penalty: 0.0
seed: 42
top_k: 0
min_p: 0.0
top_p: 0.9
metric_options: # Optional metric arguments
codegen_pass@1:16:
num_samples: 16
```

An optional `metric_options` key can be added to the YAML file, keyed by the name of the metric (or metrics) as defined by `Metric.metric_name`.
In this example, the `codegen_pass@1:16` metric defined in our tasks will have its `num_samples` updated to 16, regardless of the default value it was defined with.
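As a rough sketch of how this section is consumed (an editor's illustration, not part of the PR), the `metric_options` block is read together with the rest of the `model` section and ends up as a plain dictionary keyed by metric name:

```python
import yaml

# Same structure as the example above, inlined so the sketch is self-contained.
raw_config = """
model:
  base_params:
    model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16"
  metric_options:
    codegen_pass@1:16:
      num_samples: 16
"""

config = yaml.safe_load(raw_config)["model"]
metric_options = config.get("metric_options", {})  # {} when the key is absent
print(metric_options)  # {'codegen_pass@1:16': {'num_samples': 16}}
```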
5 changes: 4 additions & 1 deletion src/lighteval/main_vllm.py
@@ -134,9 +134,11 @@ def vllm(
with open(model_args, "r") as f:
config = yaml.safe_load(f)["model"]
model_args = config["base_params"]["model_args"]
metric_options = config.get("metric_options", {})
Review comment (Member): can you add some docs for this?

generation_parameters = GenerationParameters.from_dict(config)
else:
generation_parameters = GenerationParameters()
generation_parameters = GenerationParameters.from_model_args(model_args)
metric_options = {}

model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
model_config = VLLMModelConfig(**model_args_dict, generation_parameters=generation_parameters)
@@ -146,6 +148,7 @@ def vllm(
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
metric_options=metric_options,
)

pipeline.evaluate()
28 changes: 28 additions & 0 deletions src/lighteval/models/model_input.py
@@ -59,6 +59,34 @@ def from_dict(cls, config_dict: dict):
"""
return GenerationParameters(**config_dict.get("generation", {}))

@classmethod
def from_model_args(cls, model_args: str):
"""Creates a GenerationParameters object from a model_args string.

It is used when the model_args are passed as a string on the command line.
The generation parameters must use the following format (anywhere in the string):
"generation_parameters={key1:value1,key2:value2}"

Args:
model_args (str): A string like the following:
"pretrained=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B,dtype=float16,max_model_length=32768,generation={temperature:0.7,top_p:5}"
"""

def parse_model_args(model_args):
import json
import re

pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
matches = pattern.findall(model_args)
for key, value in matches:
key = key.strip()
if key == "generation_parameters":
gen_params = re.sub(r"(\w+):", r'"\1":', value)
return json.loads(gen_params)

params: dict = parse_model_args(model_args) or {}
return GenerationParameters(**params)

def to_litellm_dict(self) -> dict:
"""Selects relevant generation and sampling parameters for litellm models.
Doc: https://docs.litellm.ai/docs/completion/input#input-params-1
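For illustration, here is a minimal, self-contained sketch (an editor's example, not part of the diff) of the parsing behaviour introduced by the new `GenerationParameters.from_model_args` classmethod above; the model string is hypothetical:

```python
import json
import re

def parse_generation_parameters(model_args: str) -> dict:
    # Mirrors the regex-based parsing above: find key=value pairs, then quote the
    # bare keys inside the generation_parameters braces so the value parses as JSON.
    pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
    for key, value in pattern.findall(model_args):
        if key.strip() == "generation_parameters":
            return json.loads(re.sub(r"(\w+):", r'"\1":', value))
    return {}

# Hypothetical command-line string.
model_args = (
    "pretrained=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B,dtype=float16,"
    "generation_parameters={temperature:0.7,top_p:0.9}"
)
print(parse_generation_parameters(model_args))  # {'temperature': 0.7, 'top_p': 0.9}
```

Note that the substitution only quotes the keys, so the values inside the braces must already be valid JSON literals (numbers, `true`/`false`, and so on).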
19 changes: 19 additions & 0 deletions src/lighteval/pipeline.py
@@ -132,6 +132,7 @@ def __init__(
evaluation_tracker: EvaluationTracker,
model_config=None,
model=None,
metric_options=None,
):
if not (model or model_config):
raise ValueError("Must provide either a model or model config when creating a pipeline.")
@@ -145,6 +146,7 @@

self.model_config = model_config
self.evaluation_tracker = evaluation_tracker
self._metric_options = metric_options or {}
self.accelerator, self.parallel_context = self._init_parallelism_manager()
self.model = self._init_model(model_config, model)

@@ -209,6 +211,10 @@ def _init_tasks_and_requests(self, tasks: str):
)
task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
task_dict = registry.get_task_dict(task_names_list)
# If metric_options were defined in the yaml file,
# check whether any task metrics need to be updated.
if self._metric_options:
self._update_num_samples(task_dict)
LightevalTask.load_datasets(list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes)

self.evaluation_tracker.task_config_logger.log(task_dict)
@@ -230,6 +236,19 @@ def _init_tasks_and_requests(self, tasks: str):
self.requests = requests
self.docs = docs

def _update_num_samples(self, task_dict: dict[str, LightevalTask]):
"""Helper function to update the num_samples of a given metric via the yaml file.
As it has to be done at the metric level, it's better to update the value per metric.
It will add a num_samples to the already defined metrics' num_samples if defined in the yaml file.
As later when constructing the requests the max is taken over the num_samples, this is valid.
"""
for _, task in task_dict.items():
for metric in task.metrics:
if metric_data := self._metric_options.get(metric.metric_name, None):
num_samples = metric_data.get("num_samples", None)
if num_samples:
task.num_samples = [num_samples]

def _init_random_seeds(self):
logger.info("--- INIT SEEDS ---")
random.seed(1234)
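And, for reference, a self-contained sketch of the override performed by `_update_num_samples` above, using stand-in classes rather than the real `LightevalTask` and `Metric` objects (class names and the task key are illustrative):

```python
from dataclasses import dataclass, field

@dataclass
class Metric:            # stand-in for lighteval's Metric
    metric_name: str

@dataclass
class Task:              # stand-in for LightevalTask
    metrics: list
    num_samples: list = field(default_factory=lambda: [1])

# What the yaml's metric_options section parses to.
metric_options = {"codegen_pass@1:16": {"num_samples": 16}}
task_dict = {"lcb:codegeneration": Task(metrics=[Metric("codegen_pass@1:16")])}

# Same loop as Pipeline._update_num_samples, with self._metric_options replaced
# by the local metric_options dictionary.
for task in task_dict.values():
    for metric in task.metrics:
        if metric_data := metric_options.get(metric.metric_name, None):
            if num_samples := metric_data.get("num_samples", None):
                task.num_samples = [num_samples]

print(task_dict["lcb:codegeneration"].num_samples)  # [16]
```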
3 changes: 2 additions & 1 deletion src/lighteval/tasks/extended/__init__.py
@@ -25,12 +25,13 @@

if can_load_extended_tasks():
import lighteval.tasks.extended.ifeval.main as ifeval
import lighteval.tasks.extended.lcb.main as lcb
import lighteval.tasks.extended.mix_eval.main as mix_eval
import lighteval.tasks.extended.mt_bench.main as mt_bench
import lighteval.tasks.extended.olympiade_bench.main as olympiad_bench
import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks

AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench]
AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench, lcb]

else:
AVAILABLE_EXTENDED_TASKS_MODULES = []