diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx
index 2c6ba8dda..d5e7c8651 100644
--- a/docs/source/evaluating-a-custom-model.mdx
+++ b/docs/source/evaluating-a-custom-model.mdx
@@ -15,7 +15,7 @@ Here's a basic example:
 ```python
 from lighteval.models.abstract_model import LightevalModel
 from lighteval.models.model_output import ModelResponse
-from lighteval.tasks.requests import Doc
+from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.cache_management import SampleCache, cached
 
 class MyCustomModel(LightevalModel):
@@ -26,18 +26,18 @@ class MyCustomModel(LightevalModel):
         # Enable caching (recommended)
         self._cache = SampleCache(config)
 
-    @cached("predictions")  # Enable caching for better performance
-    def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
+    @cached(SamplingMethod.GENERATIVE)
+    def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
         # Implement generation logic
         pass
 
-    @cached("predictions")  # Enable caching for better performance
-    def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
+    @cached(SamplingMethod.LOGPROBS)
+    def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
         # Implement loglikelihood computation
         pass
 
-    @cached("predictions")  # Enable caching for better performance
-    def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]:
+    @cached(SamplingMethod.PERPLEXITY)
+    def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]:
         # Implement rolling loglikelihood computation
         pass
 ```
@@ -179,13 +179,12 @@ def __init__(self, config):
     self._cache = SampleCache(config)
 ```
 
-### Step 3: Add Cache Decorators
-Add cache decorators to your prediction methods:
-```python
-@cached("predictions")
-def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
-    # Your implementation...
-```
+3. Add cache decorators to your prediction methods:
+   ```python
+   @cached(SamplingMethod.GENERATIVE)
+   def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
+       # Your implementation...
+   ```
 
 For detailed information about the caching system, see the [Caching Documentation](caching).
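The documentation above leaves the three prediction methods as `pass` stubs. For comparison, a minimal self-contained backend under the new decorator scheme might look like the sketch below. It is an illustration only: the `EchoModel` name, the constant outputs, the zero logprobs, and the tokenizer/length stubs are assumptions modeled on the `DummyModel` changes later in this diff, and any other abstract members of `LightevalModel` are omitted.

```python
# Sketch of a custom backend under the new cache API (assumptions noted in comments).
from lighteval.models.abstract_model import LightevalModel, ModelConfig
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.cache_management import SampleCache, cached


class EchoModel(LightevalModel):
    """Toy backend that returns constant answers, mirroring the DummyModel in this diff."""

    def __init__(self, config: ModelConfig):
        self.config = config
        self._tokenizer = None  # a real backend would load a tokenizer here
        self._cache = SampleCache(config)  # enables the @cached decorators below

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        return 2048

    @cached(SamplingMethod.GENERATIVE)
    def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
        # One ModelResponse per input doc; results are cached under the GENERATIVE split.
        return [ModelResponse(text=["echo"]) for _ in docs]

    @cached(SamplingMethod.LOGPROBS)
    def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
        # A real backend would score each choice; zeros keep the sketch short.
        return [ModelResponse(logprobs=[0.0] * len(doc.choices)) for doc in docs]

    @cached(SamplingMethod.PERPLEXITY)
    def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]:
        return [ModelResponse(logprobs=[0.0]) for _ in docs]
```

Because each method is tagged with a different `SamplingMethod`, its results are stored and looked up in a separate cache split even when the same docs flow through all three methods.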
diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 28a4b90c2..b87a83a9f 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -50,6 +50,17 @@ class CorpusLevelComputation(ABC): def compute_corpus(self): raise NotImplementedError + def __str__(self): + attrs = vars(self) + attr_strs = [] + for k, v in attrs.items(): + if callable(v): + val_str = v.__name__ + else: + val_str = str(v) + attr_strs.append(f"{k}={val_str}") + return f"{self.__class__.__name__}({', '.join(attr_strs)})" + # General aggregations class MatthewsCorrCoef(CorpusLevelComputation): diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 635f9d9de..38a0e2f52 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -66,6 +66,17 @@ class SampleLevelComputation(ABC): def compute(self, model_response: ModelResponse, doc: Doc, **kwargs): raise NotImplementedError + def __str__(self): + attrs = vars(self) + attr_strs = [] + for k, v in attrs.items(): + if callable(v): + val_str = v.__name__ + else: + val_str = str(v) + attr_strs.append(f"{k}={val_str}") + return f"{self.__class__.__name__}({', '.join(attr_strs)})" + class ExactMatches(SampleLevelComputation): def __init__( @@ -1109,10 +1120,11 @@ def __init__( self.strip_strings = strip_strings if callable(sample_scoring_function): - self.score_sample = sample_scoring_function + self.compute_score = sample_scoring_function self.type_exact_match = None elif isinstance(sample_scoring_function, SampleLevelComputation): self.score_sample = sample_scoring_function.compute + self.type_exact_match = None else: if isinstance(sample_scoring_function, str): if sample_scoring_function not in ["prefix", "suffix", "full"]: @@ -1199,7 +1211,7 @@ def __init__(self, k: int | None = None, **kwargs): k (int): The number of top choices to consider. **kwargs: Additional keyword arguments. 
""" - super().__init__(kwargs) + super().__init__(**kwargs) self.k = k self.attribute_must_be_set = ["k"] @@ -1280,7 +1292,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: elif len(predictions) < self.n: logger.warning(f"Number of predictions is less than {self.n} for pass@k.") - processed_choices = [self.preprocess(text=g) for g in doc.choices] + processed_choices = [self.preprocess(g) for g in doc.choices] new_doc = Doc( choices=processed_choices, query=doc.query, @@ -1289,7 +1301,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: all_scores = [] for pred in predictions[: self.n]: - cur_pred = self.preprocess(text=pred) + cur_pred = self.preprocess(pred) new_model_response = ModelResponse( text=[cur_pred], ) diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index 725ba6fc6..8d79cb448 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -81,6 +81,17 @@ def prepare(doc: Doc, model_response: ModelResponse, **kwargs): predictions = model_response.final_text return GenerativeCorpusMetricInput(golds=golds, preds=predictions) + def __str__(self): + attrs = vars(self) + attr_strs = [] + for k, v in attrs.items(): + if callable(v): + val_str = v.__name__ + else: + val_str = str(v) + attr_strs.append(f"{k}={val_str}") + return f"{self.__class__.__name__}({', '.join(attr_strs)})" + class LoglikelihoodPreparator(Preparator): def __init__(self, is_single_token: bool = False): diff --git a/src/lighteval/metrics/utils/metric_utils.py b/src/lighteval/metrics/utils/metric_utils.py index c23a7854c..e57e56724 100644 --- a/src/lighteval/metrics/utils/metric_utils.py +++ b/src/lighteval/metrics/utils/metric_utils.py @@ -95,6 +95,10 @@ def __call__(self, sample_params: dict | None): self.metric_name = f"{self.metric_name}_with_{sample_params_name}" return self + @staticmethod + def get_allowed_types_for_metrics(): + return (SampleLevelComputation, Preparator, CorpusLevelComputation, Callable) + @dataclass class MetricGrouping(Metric): diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py index 46487656e..81d725e6a 100644 --- a/src/lighteval/models/abstract_model.py +++ b/src/lighteval/models/abstract_model.py @@ -46,6 +46,8 @@ class ModelConfig(BaseModel, extra="forbid"): as well as shared attributes that are used by all models like generation parameters and system prompts. Attributes: + model_name (str): + The model name or unique id generation_parameters (GenerationParameters): Configuration parameters that control text generation behavior, including temperature, top_p, max_new_tokens, etc. Defaults to empty GenerationParameters. 
@@ -80,6 +82,8 @@ class ModelConfig(BaseModel, extra="forbid"): ``` """ + model_name: str = None + generation_parameters: GenerationParameters = GenerationParameters() system_prompt: str | None = None cache_dir: str = "~/.cache/huggingface/lighteval" diff --git a/src/lighteval/models/dummy/dummy_model.py b/src/lighteval/models/dummy/dummy_model.py index 157996150..e0a13b589 100644 --- a/src/lighteval/models/dummy/dummy_model.py +++ b/src/lighteval/models/dummy/dummy_model.py @@ -28,7 +28,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached @@ -87,11 +87,11 @@ def add_special_tokens(self): def max_length(self) -> int: return 2048 - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]: return [ModelResponse(text=["random baseline"]) for _ in range(len(docs))] - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: model_responses = [] for doc in docs: @@ -104,7 +104,7 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return model_responses - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: model_responses = [] for doc in docs: diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 7bede3851..6b08be575 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -48,7 +48,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached @@ -555,7 +555,7 @@ def _process_batch_logprob(self, docs: list[Doc], rolling: bool = False) -> list for context, doc in zip(contexts, docs) ] - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: List[Doc], @@ -599,11 +599,11 @@ def _greedy_until(self, docs: List[Doc]) -> list[ModelResponse]: return dataset.get_original_order(results) - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=False) - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc], override_bs=None) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=True) diff --git a/src/lighteval/models/endpoints/inference_providers_model.py b/src/lighteval/models/endpoints/inference_providers_model.py index 20d3c94e4..bda9a517f 100644 --- a/src/lighteval/models/endpoints/inference_providers_model.py +++ b/src/lighteval/models/endpoints/inference_providers_model.py @@ -35,7 +35,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import 
SampleCache, cached @@ -196,7 +196,7 @@ async def bounded_api_call(prompt, num_samples): return results - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -253,14 +253,14 @@ def max_length(self) -> int: logger.warning("Tokenizer was not correctly loaded. Max model context length is assumed to be 30K tokens") return 30000 - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 544123e00..00f7c8779 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -30,7 +30,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_litellm_available @@ -258,7 +258,7 @@ def __call_api_parallel( return results - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -323,14 +323,14 @@ def max_length(self) -> int: """Return the maximum sequence length of the model.""" return 4096 - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 686111e04..310843d32 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -48,6 +48,7 @@ from lighteval.models.transformers.transformers_model import LightevalModel from lighteval.tasks.requests import ( Doc, + SamplingMethod, ) from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_nanotron_available @@ -483,7 +484,7 @@ def _check_continuations_start_space(self, continuation: str) -> str: continuation = continuation.lstrip() return continuation - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. 
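The edits to this file and the surrounding model backends all apply one pattern: the single `@cached("predictions")` bucket is split into `GENERATIVE`, `LOGPROBS`, and `PERPLEXITY`, so that `greedy_until`, `loglikelihood`, and `loglikelihood_rolling` results for the same sample no longer share a cache key. A standalone toy version of that idea is sketched below; `ToyModel`, `cached_by_method`, and the in-memory dict store are invented for illustration, while lighteval's real decorator persists parquet files through `SampleCache`.

```python
from enum import Enum, auto
from functools import wraps


class SamplingMethod(Enum):
    # Mirrors the three members used throughout this diff; the real enum
    # lives in lighteval.tasks.requests.
    GENERATIVE = auto()
    LOGPROBS = auto()
    PERPLEXITY = auto()


def cached_by_method(method: SamplingMethod):
    """Toy decorator: memoize per (sampling method, sample id) instead of one shared bucket."""

    def decorator(func):
        @wraps(func)
        def wrapper(self, docs):
            store = self._toy_cache.setdefault(method, {})
            missing = [d for d in docs if d["id"] not in store]
            if missing:
                # Only uncached docs are processed; results are written back by sample id.
                for doc, result in zip(missing, func(self, missing)):
                    store[doc["id"]] = result
            return [store[d["id"]] for d in docs]

        return wrapper

    return decorator


class ToyModel:
    def __init__(self):
        self._toy_cache = {}

    @cached_by_method(SamplingMethod.GENERATIVE)
    def greedy_until(self, docs):
        return [f"generated for {d['id']}" for d in docs]

    @cached_by_method(SamplingMethod.LOGPROBS)
    def loglikelihood(self, docs):
        return [-1.0 for _ in docs]


model = ToyModel()
docs = [{"id": 0}, {"id": 1}]
model.greedy_until(docs)   # computed once, then cached under GENERATIVE
model.loglikelihood(docs)  # cached separately under LOGPROBS, no collision
```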
@@ -506,7 +507,7 @@ def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), ) - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, requests: List[Doc]) -> List[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" for request in tqdm( @@ -941,7 +942,7 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) @torch.inference_mode() - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, requests: List[Doc], diff --git a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index 3c39315a8..fe37c64f9 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -33,7 +33,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.models.utils import _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_sglang_available @@ -221,7 +221,7 @@ def _create_auto_tokenizer(self, config: SGLangModelConfig): tokenizer.pad_token = tokenizer.eos_token return tokenizer - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -347,7 +347,7 @@ def _generate( ) return outputs - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood_tokens(docs) @@ -416,6 +416,6 @@ def _loglikelihood_tokens( res.append(answer) return dataset.get_original_order(res) - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index 3d0d7c12b..0fb6df464 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -52,7 +52,7 @@ ) from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import ( is_accelerate_available, @@ -738,7 +738,7 @@ def _padded_greedy_until( return dataset.get_original_order(results) - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -867,7 +867,7 @@ def _generate( else: return self._generate_padded(**kwargs) - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood( self, docs: list[Doc], @@ -883,7 +883,7 @@ def loglikelihood( """ return self._loglikelihood_tokens(docs) - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling( self, docs: list[Doc], diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py index 61f9d58ab..3da1290be 100644 --- a/src/lighteval/models/transformers/vlm_transformers_model.py +++ 
b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -44,7 +44,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import ( is_accelerate_available, @@ -338,7 +338,7 @@ def _init_max_length(self) -> int: return 2048 - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -428,14 +428,14 @@ def _greedy_until( return dataset.get_original_order(results) - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood( self, docs: list[Doc], ) -> list[ModelResponse]: raise NotImplementedError() - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling( self, docs: list[Doc], diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index eac77e640..31ec8b5a3 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -36,7 +36,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.models.utils import _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_vllm_available @@ -255,6 +255,7 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]: "seed": int(config.seed), "max_num_seqs": int(config.max_num_seqs), "max_num_batched_tokens": int(config.max_num_batched_tokens), + "enforce_eager": True, } if config.quantization is not None: @@ -296,7 +297,7 @@ def _create_auto_tokenizer(self, config: VLLMModelConfig): tokenizer.pad_token = tokenizer.eos_token return tokenizer - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -453,7 +454,7 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r return outputs - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood_tokens(docs) @@ -481,7 +482,8 @@ def _loglikelihood_tokens( tokenized_contexts_batch.append(tokenized_context) # Left truncate the inputs to the maximum length - inputs = [input[-self.max_length :] for input in inputs] + if self.max_length: # can be None if the model is initialized with ray + inputs = [input[-self.max_length :] for input in inputs] outputs = self._generate(inputs, generate=False) flat_index = 0 @@ -521,7 +523,7 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) - @cached("predictions") + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() @@ -616,7 +618,7 @@ async def _async_batch(self, docs: list[Doc], generative: bool) -> list: results = await asyncio.gather(*processed_requests) return results - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) async def greedy_until( self, docs: list[Doc], @@ -650,7 +652,7 @@ async def greedy_until( return results - @cached("predictions") + @cached(SamplingMethod.LOGPROBS) async def loglikelihood( self, docs: 
list[Doc], diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 71b1efd4a..f79e8c910 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -176,6 +176,8 @@ def __init__( self.model_config = model_config self.accelerator, self.parallel_context = self._init_parallelism_manager() self.model = self._init_model(model_config, model) + # Must occur after model and task init + self.model._cache._init_registry(self.registry) # Must occur after model init self._init_accelerator_seeds() @@ -241,10 +243,10 @@ def _init_tasks_and_requests(self, tasks: str): logger.info("--- LOADING TASKS ---") # The registry contains all the potential tasks - registry = Registry(tasks=tasks, custom_tasks=self.pipeline_parameters.custom_tasks_directory) + self.registry = Registry(tasks=tasks, custom_tasks=self.pipeline_parameters.custom_tasks_directory) # load the tasks from the configs and their datasets - self.tasks_dict: dict[str, LightevalTask] = registry.load_tasks() + self.tasks_dict: dict[str, LightevalTask] = self.registry.load_tasks() LightevalTask.load_datasets(self.tasks_dict, self.pipeline_parameters.dataset_loading_processes) self.documents_dict = { task.full_name: task.get_docs(self.pipeline_parameters.max_samples) for _, task in self.tasks_dict.items() diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index b1d766155..a78860168 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -1862,6 +1862,27 @@ def mmlu_helm(line, task_name: str = None): ) +def mmlu_redux_2(line, topic, task_name: str = None): + """ + Ref: https://arxiv.org/abs/2406.04127 + """ + query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + query += line["question"] + "\n" + query += "".join([f"{key}. 
{choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + query += "Answer: " + + # Handle answer format - MMLU-Redux-2 uses integer indices directly + gold_ix = line["answer"] if isinstance(line["answer"], int) else int(line["answer"]) + + return Doc( + task_name=task_name, + query=query, + choices=LETTER_INDICES[: len(line["choices"])], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + ) + + def mmlu_qa_abstract_algebra(line, task_name: str = None): return mmlu_qa(line, "abstract_algebra", task_name) diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 8e3f286be..7abecc929 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -22727,3 +22727,145 @@ stop_sequence=["\n"], version=0, ) + +# MMLU-Redux-2 Tasks +_MMLU_REDUX_2_SUBSETS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + "computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + + +_mmlu_redux_2_tasks = { + subset: LightevalTaskConfig( + name=f"mmlu_redux_2:{subset}", + suite=["lighteval"], + prompt_function=lambda line, task_name=None, s=subset: prompt.mmlu_redux_2(line, s, task_name), + hf_repo="edinburgh-dawg/mmlu-redux-2.0", + hf_subset=subset, + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=1, + metrics=[ + Metrics.loglikelihood_acc, + Metrics.pass_at_k_letters(sample_params={"k": 1}), + ], + stop_sequence=["\n"], + version=0, + ) + for subset in _MMLU_REDUX_2_SUBSETS +} + +mmlu_redux_2_abstract_algebra = _mmlu_redux_2_tasks["abstract_algebra"] +mmlu_redux_2_anatomy = _mmlu_redux_2_tasks["anatomy"] +mmlu_redux_2_astronomy = _mmlu_redux_2_tasks["astronomy"] +mmlu_redux_2_business_ethics = _mmlu_redux_2_tasks["business_ethics"] +mmlu_redux_2_clinical_knowledge = _mmlu_redux_2_tasks["clinical_knowledge"] +mmlu_redux_2_college_biology = _mmlu_redux_2_tasks["college_biology"] +mmlu_redux_2_college_chemistry = _mmlu_redux_2_tasks["college_chemistry"] +mmlu_redux_2_college_computer_science = _mmlu_redux_2_tasks["college_computer_science"] +mmlu_redux_2_college_mathematics = _mmlu_redux_2_tasks["college_mathematics"] +mmlu_redux_2_college_medicine = 
_mmlu_redux_2_tasks["college_medicine"] +mmlu_redux_2_college_physics = _mmlu_redux_2_tasks["college_physics"] +mmlu_redux_2_computer_security = _mmlu_redux_2_tasks["computer_security"] +mmlu_redux_2_conceptual_physics = _mmlu_redux_2_tasks["conceptual_physics"] +mmlu_redux_2_econometrics = _mmlu_redux_2_tasks["econometrics"] +mmlu_redux_2_electrical_engineering = _mmlu_redux_2_tasks["electrical_engineering"] +mmlu_redux_2_elementary_mathematics = _mmlu_redux_2_tasks["elementary_mathematics"] +mmlu_redux_2_formal_logic = _mmlu_redux_2_tasks["formal_logic"] +mmlu_redux_2_global_facts = _mmlu_redux_2_tasks["global_facts"] +mmlu_redux_2_high_school_biology = _mmlu_redux_2_tasks["high_school_biology"] +mmlu_redux_2_high_school_chemistry = _mmlu_redux_2_tasks["high_school_chemistry"] +mmlu_redux_2_high_school_computer_science = _mmlu_redux_2_tasks["high_school_computer_science"] +mmlu_redux_2_high_school_european_history = _mmlu_redux_2_tasks["high_school_european_history"] +mmlu_redux_2_high_school_geography = _mmlu_redux_2_tasks["high_school_geography"] +mmlu_redux_2_high_school_government_and_politics = _mmlu_redux_2_tasks["high_school_government_and_politics"] +mmlu_redux_2_high_school_macroeconomics = _mmlu_redux_2_tasks["high_school_macroeconomics"] +mmlu_redux_2_high_school_mathematics = _mmlu_redux_2_tasks["high_school_mathematics"] +mmlu_redux_2_high_school_microeconomics = _mmlu_redux_2_tasks["high_school_microeconomics"] +mmlu_redux_2_high_school_physics = _mmlu_redux_2_tasks["high_school_physics"] +mmlu_redux_2_high_school_psychology = _mmlu_redux_2_tasks["high_school_psychology"] +mmlu_redux_2_high_school_statistics = _mmlu_redux_2_tasks["high_school_statistics"] +mmlu_redux_2_high_school_us_history = _mmlu_redux_2_tasks["high_school_us_history"] +mmlu_redux_2_high_school_world_history = _mmlu_redux_2_tasks["high_school_world_history"] +mmlu_redux_2_human_aging = _mmlu_redux_2_tasks["human_aging"] +mmlu_redux_2_human_sexuality = _mmlu_redux_2_tasks["human_sexuality"] +mmlu_redux_2_international_law = _mmlu_redux_2_tasks["international_law"] +mmlu_redux_2_jurisprudence = _mmlu_redux_2_tasks["jurisprudence"] +mmlu_redux_2_logical_fallacies = _mmlu_redux_2_tasks["logical_fallacies"] +mmlu_redux_2_machine_learning = _mmlu_redux_2_tasks["machine_learning"] +mmlu_redux_2_management = _mmlu_redux_2_tasks["management"] +mmlu_redux_2_marketing = _mmlu_redux_2_tasks["marketing"] +mmlu_redux_2_medical_genetics = _mmlu_redux_2_tasks["medical_genetics"] +mmlu_redux_2_miscellaneous = _mmlu_redux_2_tasks["miscellaneous"] +mmlu_redux_2_moral_disputes = _mmlu_redux_2_tasks["moral_disputes"] +mmlu_redux_2_moral_scenarios = _mmlu_redux_2_tasks["moral_scenarios"] +mmlu_redux_2_nutrition = _mmlu_redux_2_tasks["nutrition"] +mmlu_redux_2_philosophy = _mmlu_redux_2_tasks["philosophy"] +mmlu_redux_2_prehistory = _mmlu_redux_2_tasks["prehistory"] +mmlu_redux_2_professional_accounting = _mmlu_redux_2_tasks["professional_accounting"] +mmlu_redux_2_professional_law = _mmlu_redux_2_tasks["professional_law"] +mmlu_redux_2_professional_medicine = _mmlu_redux_2_tasks["professional_medicine"] +mmlu_redux_2_professional_psychology = _mmlu_redux_2_tasks["professional_psychology"] +mmlu_redux_2_public_relations = _mmlu_redux_2_tasks["public_relations"] +mmlu_redux_2_security_studies = _mmlu_redux_2_tasks["security_studies"] +mmlu_redux_2_sociology = _mmlu_redux_2_tasks["sociology"] +mmlu_redux_2_us_foreign_policy = _mmlu_redux_2_tasks["us_foreign_policy"] +mmlu_redux_2_virology = _mmlu_redux_2_tasks["virology"] 
+mmlu_redux_2_world_religions = _mmlu_redux_2_tasks["world_religions"] diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 73956c835..c54afb5fe 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import inspect import logging import random from dataclasses import asdict, dataclass, field @@ -154,20 +153,28 @@ def __post_init__(self): self.stop_sequence = self.stop_sequence if self.stop_sequence is not None else () self.full_name = f"{self.name}|{self.num_fewshots}" # todo clefourrier: this is likely incorrect - def print(self): + def __str__(self, lite: bool = False): md_writer = MarkdownTableWriter() md_writer.headers = ["Key", "Value"] + # These keys change through time + to_ignore = ["original_num_docs", "effective_num_docs"] + values = [] for k, v in asdict(self).items(): - if k == "metric": + if lite and k in to_ignore: + continue + if k == "metrics": for ix, metrics in enumerate(v): for metric_k, metric_v in metrics.items(): - if inspect.ismethod(metric_v): - values.append([f"{k} {ix}: {metric_k}", metric_v.__qualname__]) + if isinstance(metric_v, Callable): + repr_v = metric_v.__name__ + elif isinstance(metric_v, Metric.get_allowed_types_for_metrics()): + repr_v = str(metric_v) else: - values.append([f"{k} {ix}: {metric_k}", repr(metric_v)]) + repr_v = repr(metric_v) + values.append([f"{k} {ix}: {metric_k}", repr_v]) else: if isinstance(v, Callable): @@ -177,7 +184,10 @@ def print(self): md_writer.value_matrix = values - print(md_writer.dumps()) + return md_writer.dumps() + + def print(self, lite: bool = False): + print(str(self, lite)) class LightevalTask: diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 70435d715..2059d2843 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -25,8 +25,7 @@ import json import logging import os -from dataclasses import asdict -from enum import Enum +from dataclasses import asdict, dataclass from pathlib import Path from typing import Callable, List, Set, Tuple, Union @@ -35,16 +34,31 @@ from lighteval.models.abstract_model import ModelConfig from lighteval.models.model_output import ModelResponse -from lighteval.tasks.requests import Doc +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.registry import Registry +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.utils import as_list logger = logging.getLogger(__name__) -class SampleType(Enum): - PREDICTIONS = 1 - TOKENIZED_INPUTS = 2 # Not implemented yet +@dataclass +class TaskID: + """A unique ID for a grouping of task samples. 
It relies on the task name, + the task config (which gives the task_hash), and the sampling method (linked to + the metric type) + """ + + task_name: str + task_hash: str + sampling_method: SamplingMethod + + def __str__(self): + return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})" + + def __hash__(self): + return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big") class SampleCache: @@ -54,10 +68,10 @@ class SampleCache: Cache Structure: - {cache_dir}/ - - {sample_type}/ - {model_name}/ - - {model_hash}/ - - {task_name}.parquet + - {model_hash}/ + - {task_name}/ + - {task_hash}/ dataset dict, where splits are SamplingMethod """ def __init__(self, model_config: ModelConfig): @@ -67,49 +81,59 @@ def __init__(self, model_config: ModelConfig): model_config: Configuration for the model being cached cache_dir: Directory to store cache files """ - self.cache_dir = Path(os.path.expanduser(model_config.cache_dir)) self.model_config = model_config self.model_hash = self.get_model_hash(model_config) - # Create cache directory structure and load cached indices if present - self.all_cache_dirs = {} - self.existing_indices = {} - for sample_type in SampleType: - self.all_cache_dirs[sample_type] = ( - self.cache_dir / sample_type.name.lower() / self.model_config.model_name / self.model_hash - ) - self.all_cache_dirs[sample_type].mkdir(parents=True, exist_ok=True) - self.existing_indices[sample_type] = self._get_cached_indices(sample_type) + self.cache_dir = ( + Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash + ) + self.cache_dir.mkdir(parents=True, exist_ok=True) - def _get_cached_indices(self, sample_type: SampleType) -> dict: - """Loads all indices for samples which are properly cached + self.registry = None + + self.existing_indices = self._load_cached_indices() + + def _init_registry(self, registry: Registry): + self.registry = registry + + def _load_cached_indices(self) -> dict: + """Loads all indices for samples which are properly cached. We recursively search for all available tasks and files. 
Returns: dict: Dictionary mapping task names to lists of cached sample indices """ + logger.info("[CACHING] Initializing data cache") cached_indices = {} - cache_dir = self.all_cache_dirs[sample_type] + cache_dir = self.cache_dir if not cache_dir.exists(): return cached_indices - for cache_file in cache_dir.glob("*.parquet"): - task_name = cache_file.stem + for cache_file in cache_dir.rglob("*.parquet"): try: - dataset = load_dataset("parquet", data_files=str(cache_file), split="train") + # cache_file.parts gives all the subfolders of the url, up to the file name + # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2 + task_name, task_hash = cache_file.parts[-3:-1] + sampling_method = SamplingMethod[cache_file.stem] # removes the file extension + task_id = TaskID(task_name, task_hash, sampling_method) + + full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") sample_ids = [] - for row in dataset: + for row in full_dataset: try: # We only save indices of correctly formatted samples, though this means we need to load each at least once - self._load_sample(row, sample_type=sample_type) - sample_ids.append(row["sample_id"]) + self._load_sample(row) + cur_sample = row["sample_id"] + sample_ids.append(cur_sample) except Exception: continue - cached_indices[task_name] = sample_ids - logger.debug(f"Loaded {len(sample_ids)} cached indices for task '{task_name}' from {cache_file}") + cached_indices[task_id] = sample_ids + logger.info( + f"[CACHING] Loaded {len(sample_ids)} cached indices for task '{str(task_id)} from {cache_file}" + ) except Exception as e: - logger.warning(f"Error loading cached indices for task '{task_name}' from {cache_file}: {e}") + logger.warning(f"Error loading cached indices from {cache_file}: {e}") return cached_indices @@ -124,21 +148,60 @@ def get_model_hash(self, model_config: ModelConfig) -> str: config_str = json.dumps(config_dict, sort_keys=True, default=str) return hashlib.sha256(config_str.encode()).hexdigest()[:16] - def get_cache_path(self, task_name: str, sample_type: SampleType) -> Path: + def _get_task_hash(self, full_task_name: str) -> str: + """Builds a task_hash from the LightevalTaskConfig loaded from the task name and the registry. + + Args: + full_task_name (str): task_name as provided to the registry (with suite|task|few_shot) + + Returns: + str: a hash of the task config in its current state in the registry, or the NO_HASH string if the + registry has not been preloaded + """ + if self.registry is None: + logger.warning( + "The task registry was not provided to the cache config. We can't test if the current task has the same hash as the saved tasks." + ) + return "NO_HASH" + task_suite, task_name, _ = full_task_name.split("|") + task_configs: list[LightevalTaskConfig] = sorted(self.registry.task_to_configs[f"{task_suite}|{task_name}"]) + config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs]) + return hashlib.sha256(config_str.encode()).hexdigest()[:16] + + def get_cache_path(self, task_id: TaskID) -> Path: """Get the file path for a specific task's cache file. 
Args: - task_name: Name of the task - sample_type: Type of samples being cached + task_id: TaskID of the task Returns: Path: Path to the cache file for the given task and sample type """ - return self.all_cache_dirs[sample_type] / f"{task_name}.parquet" + return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet" - def _load_sample( - self, sample: pd.core.series.Series | dict, sample_type: SampleType - ) -> Union[dict, ModelResponse]: + def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID: + """Returns a unique task indentifier. Depends on the task name, + task version and parameters (from which a hash is derived), and + current sampling method (current metric we look at). + + Args: + task_name (str): Name of the task + sampling_method (SamplingMethod): Sampling used for the current metric + + Returns: + TaskID: A unique identifier for the task + """ + task_hash = self._get_task_hash(task_name) + return TaskID(task_name, task_hash, sampling_method) + + def get_sampling_method(self, sample: dict) -> str: + if len(sample.get("logprobs", [])) > 0: + return SamplingMethod.LOGPROBS + if len(sample.get("text", [])) > 0: + return SamplingMethod.GENERATIVE + return None + + def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]: """Load a sample from cached data based on sample type. Args: @@ -151,50 +214,49 @@ def _load_sample( # If we just use the pandas dict, lists are converted to np arrays which we don't want if isinstance(sample, pd.core.series.Series): sample = json.loads(sample.to_json()) - if sample_type == SampleType.TOKENIZED_INPUTS: - return sample["sample"] - elif sample_type == SampleType.PREDICTIONS: - return ModelResponse(**sample["sample"]) + return ModelResponse(**sample["sample"]) - def _dump_sample(self, result: Union[dict, ModelResponse], sample_type: SampleType) -> dict: + def _dump_sample(self, result: Union[dict, ModelResponse]) -> dict: """Dumps the sample in the correct format for file saving Args: result (Union[dict, ModelResponse]): Processed sample to save - sample_type (SampleType): Type of sample Returns: dict """ - if sample_type == SampleType.TOKENIZED_INPUTS: - return result - elif sample_type == SampleType.PREDICTIONS: - return asdict(result) + return asdict(result) - def get_notcached_samples(self, docs: List[Doc], sample_type: SampleType) -> Tuple[List[Doc], Set]: - """Identify which docs need processing based on cached indices. + def get_samples_to_process_and_cache( + self, docs: List[Doc], sampling_method: SamplingMethod + ) -> Tuple[List[Doc], Set[TaskID]]: + """ + Identify which docs need processing because they are not cached yet, based on cached doc and task indices. 
Returns: Tuple of (docs_not_cached, tasks_with_cached_samples) where - docs_not_cached contains docs that need processing - tasks_with_cached_samples are the tasks that have some cached samples """ - cached_indices = self.existing_indices[sample_type] + cached_indices = self.existing_indices docs_not_cached = [] tasks_with_cached_samples = set() for doc in docs: - task_name = doc.task_name - if task_name in cached_indices and doc.id in cached_indices[task_name]: - tasks_with_cached_samples.add(task_name) - else: + task_id = self.get_task_id(doc.task_name, sampling_method) + try: + if doc.id in cached_indices[task_id]: + tasks_with_cached_samples.add(task_id) + else: + docs_not_cached.append(doc) + except KeyError: # task id or sampling method not yet there docs_not_cached.append(doc) return docs_not_cached, set(tasks_with_cached_samples) def get_samples_from_cache( - self, docs: List[Doc], task_names: list | set, sample_type: SampleType + self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod ) -> List[dict | ModelResponse]: """Get cached samples for the given docs. Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise. @@ -205,89 +267,96 @@ def get_samples_from_cache( # Load datasets for tasks that have cached docs task_datasets = {} - for task_name in task_names: - cache_file = self.get_cache_path(task_name=task_name, sample_type=sample_type) + for task_id in task_ids: + if task_id.sampling_method != sampling_method: + continue + cache_file = self.get_cache_path(task_id) try: dataset = load_dataset("parquet", data_files=str(cache_file), split="train") dataset_df = dataset.to_pandas().set_index("sample_id") - task_datasets[task_name] = dataset_df + task_datasets[task_id] = dataset_df except Exception as e: - logger.warning(f"Error loading {sample_type.name.lower()} cache for {task_name}: {e}") + logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}") # Build results list results = [] for doc in docs: - row = task_datasets[doc.task_name].loc[doc.id] - results.append(self._load_sample(row, sample_type)) + task_id = self.get_task_id(doc.task_name, sampling_method) + row = task_datasets[task_id].loc[doc.id] + results.append(self._load_sample(row)) return results - def store_samples( + def cache_samples( # noqa C901 self, docs: List[Doc], results: List[dict] | List[ModelResponse], - task_names: list[str], - sample_type: SampleType, + task_ids: list[TaskID], + sampling_method: SamplingMethod, ): """Store new results for samples in docs""" if not results: return # Prepare newly processed data for dataset - processed_data = {task_name: [] for task_name in task_names} + processed_data = {task_id: [] for task_id in task_ids} for doc, result in zip(docs, results): - processed_data[doc.task_name].append( - {"sample_id": doc.id, "sample": self._dump_sample(result, sample_type)} - ) - processed_data = {task_name: task_data for task_name, task_data in processed_data.items() if task_data} + task_id = self.get_task_id(doc.task_name, sampling_method) + sample = self._dump_sample(result) + + processed_data[task_id].append({"sample_id": doc.id, "sample": sample}) + processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data} # Concatenate it with existing data and save to file - for task_name, task_data in processed_data.items(): - cache_file = self.get_cache_path(task_name=task_name, sample_type=sample_type) + for task_id, task_data in processed_data.items(): + if task_id not 
in self.existing_indices.keys(): + self.existing_indices[task_id] = {} + + cache_file = self.get_cache_path(task_id) # Load existing data if present existing_data = [] + existing_samples = {} if cache_file.exists(): try: existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") existing_data = existing_dataset.to_list() + except KeyError: + logger.info(f"No data was cached for {str(task_id)}") except Exception as e: - logger.error(f"Error loading existing {sample_type.name.lower()} cache for {task_name}: {e}") + logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}") - # Merge with new data (new data overwrites existing) - existing_ids = {row["sample_id"] for row in existing_data} + existing_samples = {(row["sample_id"], sampling_method) for row in existing_data} + if any((row["sample_id"], sampling_method) in existing_samples for row in task_data): + logger.warning( + "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." + ) - if any(row["sample_id"] in existing_ids for row in task_data): - logger.warning( - "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." - ) - all_samples = existing_data + [row for row in task_data if row["sample_id"] not in existing_ids] + # Merge with new data (new data overwrites existing) + # We look at id + sampling method + new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples] + all_samples = existing_data + new_data # Save updated dataset dataset = Dataset.from_list(all_samples) dataset.to_parquet(str(cache_file)) - logger.info( - f"Cached {len(all_samples)} {sample_type.name.lower()} samples of {task_name} at {str(cache_file)}." - ) + logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.") # Refresh cached indices after storing new samples - self.existing_indices[sample_type][task_name] = [sample["sample_id"] for sample in all_samples] + self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples] -def cached(cache_type_name: str): # noqa C901 - """Decorator to cache method results based on Doc inputs. +def cached(sampling_method: SamplingMethod = None): # noqa C901 + """ + Decorator to cache method results based on Doc inputs. 
Args: cache_type_name: Type of cache ("tokenization" or "predictions") Usage: - @cached("tokenization") - def tok_encode_pair(self, docs: List[Doc], ...): - # method implementation - - @cached("predictions") + @cached(SamplingMethod.GENERATIVE) def greedy_until(self, docs: List[Doc], ...): # method implementation @@ -298,7 +367,6 @@ def greedy_until(self, docs: List[Doc], ...): def decorator(func: Callable): # noqa C901 @functools.wraps(func) def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 - cache_type = SampleType[cache_type_name.upper()] docs = as_list(docs) # Check if caching is enabled for the model @@ -308,34 +376,46 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 cache: SampleCache = self._cache # Extract task names - task_names = {doc.task_name for doc in docs} + task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs} # 1) Identify which samples must be processed because they are not cached - docs_not_cached, tasks_with_cached_samples = cache.get_notcached_samples(docs, cache_type) + docs_not_cached: List[Doc] + tasks_with_cached_samples: Set[TaskID] + docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method) # Log cache statistics cached_count = len(docs) - len(docs_not_cached) if cached_count > 0: logger.info( - f"Cache: {cached_count}/{len(docs)} {cache_type.name.lower()} samples are cached for tasks {', '.join(tasks_with_cached_samples)}" + f"Cache: {cached_count}/{len(docs)} samples are cached for tasks {', '.join(t_id.task_name for t_id in tasks_with_cached_samples)}" ) # 2) Process not cached docs and save to file new_results = [] if docs_not_cached: - notcached_task_names = {doc.task_name for doc in docs_not_cached} + tasks_needing_sample_processing = { + cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached + } logger.info( - f"Cache: Processing {len(docs_not_cached)}/{len(docs)} {cache_type.name.lower()} samples for tasks {', '.join(notcached_task_names)}" + f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}" ) new_results = func(self, docs_not_cached, *args, **kwargs) # Store new results in file cache - cache.store_samples( - docs=docs_not_cached, results=new_results, task_names=task_names, sample_type=cache_type + cache.cache_samples( + docs=docs_not_cached, + results=new_results, + task_ids=task_ids, + sampling_method=sampling_method, ) # 3) Create final results by pulling from newly saved file cache - final_results = cache.get_samples_from_cache(docs, task_names, cache_type) + final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method) + + # 4) We only keep samples with the correct sampling method + final_results = [ + s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method + ] if any(r is None for r in final_results): raise ValueError("Problem while loading and aggregating items from cache.") diff --git a/tests/utils.py b/tests/utils.py index b44d27551..7954b3531 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,6 +34,7 @@ from lighteval.tasks.lighteval_task import LightevalTask from lighteval.tasks.registry import Registry from lighteval.tasks.requests import Doc +from lighteval.utils.cache_management import SampleCache from lighteval.utils.imports import is_accelerate_available @@ -55,6 +56,7 @@ def __init__( 
self.greedy_until_responses = greedy_until_responses self.loglikelihood_responses = loglikelihood_responses self.loglikelihood_rolling_responses = loglikelihood_rolling_responses + self._cache = SampleCache(self.config) @property def tokenizer(self): diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 1d8f6060d..47ea599ae 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -30,8 +30,8 @@ from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_output import ModelResponse -from lighteval.tasks.requests import Doc -from lighteval.utils.cache_management import SampleCache, SampleType +from lighteval.tasks.requests import Doc, SamplingMethod +from lighteval.utils.cache_management import SampleCache class TestCaching(unittest.TestCase): @@ -95,12 +95,10 @@ def test_cache_directory_structure(self): cache = SampleCache(config) # Check directory structure - cache_dirs = list(cache.all_cache_dirs.values()) - self.assertEqual(len(cache_dirs), 2) - for folder in cache_dirs: - self.assertTrue(folder.exists()) - self.assertIn(str(temp_dir), str(folder)) - self.assertIn(model_name, str(folder)) + folder = cache.cache_dir + self.assertTrue(folder.exists()) + self.assertIn(str(temp_dir), str(folder)) + self.assertIn(model_name, str(folder)) def test_cache_decorator_presence(self): """Test that @cached decorators are present on the right methods.""" @@ -142,35 +140,37 @@ def test_cache_decorator_presence(self): hasattr(method, "__wrapped__"), f"{method_name} missing @cached decorator for {model_class}" ) - def _test_cache(self, model: LightevalModel): + def _test_cache(self, model: LightevalModel, test_cases): """Test that the @cached decorator logic works correctly - called by all model specific functions below.""" - for function_name in ["greedy_until", "loglikelihood", "loglikelihood_rolling"]: + for function_name, sampling_method in test_cases: with self.subTest(function_name=function_name): process_inputs = getattr(model, function_name) process_inputs(self.docs) cache: SampleCache = model._cache + # Check task_id + task_id = cache.get_task_id(self.task_name, sampling_method) + self.assertEqual(task_id.task_name, self.task_name) + self.assertEqual(task_id.sampling_method, sampling_method) + # Verify cache files were created - cache_file = cache.get_cache_path(task_name=self.task_name, sample_type=SampleType.PREDICTIONS) + cache_file = cache.get_cache_path(task_id) self.assertTrue(cache_file.exists(), "Cache file not created") # Test retrieving from cache - self.assertEqual( - cache._get_cached_indices(SampleType.PREDICTIONS), {self.task_name: [doc.id for doc in self.docs]} + self.assertEqual(cache._load_cached_indices()[task_id], [doc.id for doc in self.docs]) + uncached_docs, tasks_with_cached_samples = cache.get_samples_to_process_and_cache( + docs=self.docs, sampling_method=sampling_method ) - uncached_docs, tasks_with_cached_samples = cache.get_notcached_samples( - docs=self.docs, sample_type=SampleType.PREDICTIONS - ) - - self.assertEqual(tasks_with_cached_samples, {self.task_name}) + self.assertEqual(tasks_with_cached_samples, {task_id}) self.assertEqual( len(uncached_docs), 0, f"{len(uncached_docs)} documents not found in cache when it should be 0" ) # Verify cached results match original cached_responses = cache.get_samples_from_cache( - docs=self.docs, task_names=[self.task_name], sample_type=SampleType.PREDICTIONS + docs=self.docs, task_ids=[task_id], sampling_method=sampling_method ) for cached_response, 
response in zip(cached_responses, self.model_responses): self.assertEqual(asdict(cached_response), asdict(response)) @@ -200,11 +200,18 @@ def test_cache_transformers( config = TransformersModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir) model = TransformersModel(config) - self._test_cache(model) + self._test_cache( + model, + [ + ("greedy_until", SamplingMethod.GENERATIVE), + ("loglikelihood", SamplingMethod.LOGPROBS), + ("loglikelihood_rolling", SamplingMethod.PERPLEXITY), + ], + ) - @patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model") - @patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until") @patch("lighteval.models.vllm.vllm_model.VLLMModel._loglikelihood_tokens") + @patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until") + @patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model") def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikelihood): from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig @@ -217,7 +224,13 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikeliho config = VLLMModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir) model = VLLMModel(config) - self._test_cache(model) + self._test_cache( + model, + [ + ("greedy_until", SamplingMethod.GENERATIVE), + ("loglikelihood", SamplingMethod.LOGPROBS), + ], + ) @patch("requests.get") @patch("lighteval.models.endpoints.tgi_model.ModelClient._greedy_until") @@ -242,7 +255,14 @@ def test_cache_tgi(self, mock_loglikelihood, mock_greedy_until, mock_requests_ge ) model = ModelClient(config) - self._test_cache(model) + self._test_cache( + model, + [ + ("greedy_until", SamplingMethod.GENERATIVE), + ("loglikelihood", SamplingMethod.LOGPROBS), + ("loglikelihood_rolling", SamplingMethod.PERPLEXITY), + ], + ) @patch("lighteval.models.endpoints.endpoint_model.InferenceEndpointModel._loglikelihood") @patch("lighteval.models.endpoints.endpoint_model.InferenceEndpointModel._greedy_until") @@ -263,7 +283,14 @@ def test_cache_endpoint(self, mock_init, mock_greedy_until, mock_loglikelihood): config = InferenceEndpointModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir) model = InferenceEndpointModel(config) - self._test_cache(model) + self._test_cache( + model, + [ + ("greedy_until", SamplingMethod.GENERATIVE), + ("loglikelihood", SamplingMethod.LOGPROBS), + ("loglikelihood_rolling", SamplingMethod.PERPLEXITY), + ], + ) @patch("lighteval.models.sglang.sglang_model.SGLangModel._loglikelihood_tokens") @patch("lighteval.models.sglang.sglang_model.SGLangModel._greedy_until") @@ -284,7 +311,13 @@ def test_cache_sglang( config = SGLangModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir) model = SGLangModel(config) - self._test_cache(model) + self._test_cache( + model, + [ + ("greedy_until", SamplingMethod.GENERATIVE), + ("loglikelihood", SamplingMethod.LOGPROBS), + ], + ) @patch("lighteval.models.transformers.vlm_transformers_model.VLMTransformersModel._greedy_until") @patch("lighteval.utils.imports.is_accelerate_available") @@ -312,4 +345,9 @@ def test_cache_vlm_transformers( config = VLMTransformersModelConfig(model_name="HuggingFaceTB/SmolVLM-256M-Instruct", cache_dir=temp_dir) model = VLMTransformersModel(config) - self._test_cache(model) + self._test_cache( + model, + [ + ("greedy_until", SamplingMethod.GENERATIVE), + ], + )
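For reference, the directory layout these tests exercise is, per the updated `SampleCache` docstring, `{cache_dir}/{model_name}/{model_hash}/{task_name}/{task_hash}/{SamplingMethod}.parquet`, where both hashes are SHA-256 digests truncated to 16 characters (see `get_model_hash` and `_get_task_hash` above). The sketch below reconstructs that path derivation outside lighteval; the helper names and example arguments are illustrative, not part of the library.

```python
import hashlib
import json
from pathlib import Path


def short_sha256(payload: str) -> str:
    # Both the model hash and the task hash in this diff are sha256 hex digests
    # truncated to 16 characters.
    return hashlib.sha256(payload.encode()).hexdigest()[:16]


def prediction_cache_path(
    cache_dir: str,
    model_name: str,
    model_config: dict,
    task_config: str,
    task_name: str,
    sampling_method: str,  # "GENERATIVE", "LOGPROBS" or "PERPLEXITY"
) -> Path:
    """Illustrative reconstruction of the layout described in SampleCache's docstring."""
    model_hash = short_sha256(json.dumps(model_config, sort_keys=True, default=str))
    task_hash = short_sha256(task_config)
    return (
        Path(cache_dir).expanduser()
        / model_name
        / model_hash
        / task_name
        / task_hash
        / f"{sampling_method}.parquet"
    )


print(
    prediction_cache_path(
        "~/.cache/huggingface/lighteval",
        "Qwen/Qwen3-0.6B",
        {"model_name": "Qwen/Qwen3-0.6B"},
        "serialized-task-config",
        "lighteval|mmlu_redux_2:anatomy|0",
        "GENERATIVE",
    )
)
```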