diff --git a/lm_eval/_cli/utils.py b/lm_eval/_cli/utils.py index b2a47fea009..659b7afe70c 100644 --- a/lm_eval/_cli/utils.py +++ b/lm_eval/_cli/utils.py @@ -112,7 +112,7 @@ def key_val_to_dict(args: str) -> dict[str, Any]: if not args: return res - for k, v in (item.split("=",1) for item in args.split(",")): + for k, v in (item.split("=", 1) for item in args.split(",")): v = handle_cli_value_string(v) if k in res: eval_logger.warning(f"Overwriting key '{k}': {res[k]!r} -> {v!r}") diff --git a/lm_eval/api/model.py b/lm_eval/api/model.py index b9e7bd31c4e..d22a65ed48f 100644 --- a/lm_eval/api/model.py +++ b/lm_eval/api/model.py @@ -3,7 +3,8 @@ import json import logging import os -from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, TypeVar from tqdm import tqdm @@ -31,7 +32,7 @@ def __init__(self) -> None: # set rank and world size to a single process, by default. self._rank = 0 self._world_size = 1 - self.cache_hook: "CacheHook" = CacheHook(None) + self.cache_hook: CacheHook = CacheHook(None) @abc.abstractmethod def loglikelihood(self, requests) -> list[tuple[float, bool]]: @@ -68,7 +69,8 @@ def loglikelihood_rolling(self, requests) -> list[float]: which may simply concatenate multiple documents together. - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into multiple chunks, the last input will still a full-sized context. - Example: + + Example: Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] Prefix: BOS/EOS Max context length: 4 @@ -137,7 +139,7 @@ def apply_chat_template( @classmethod def create_from_arg_string( - cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + cls: type[T], arg_string: str, additional_config: dict | None = None ) -> T: """ Creates an instance of the LM class using the given argument string and additional config. 
@@ -156,7 +158,7 @@ def create_from_arg_string( @classmethod def create_from_arg_obj( - cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None + cls: type[T], arg_dict: dict, additional_config: dict | None = None ) -> T: """ Creates an instance of the LM class using the given arg_obj @@ -201,7 +203,7 @@ def tokenizer_name(self) -> str: "To use this model with chat templates, please implement the 'tokenizer_name' property." ) - def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + def chat_template(self, chat_template: bool | str = False) -> str | None: """Returns the chat template structure for user/assistant messages if a template is provided. This method is intended to be overridden in a subclass to define a specific chat template format. For models that do not support chat templates, this method returns None by default. @@ -222,7 +224,7 @@ def hash_args(attr: str, args: Iterable[Any]) -> str: class CacheHook: def __init__(self, cachinglm: Optional["CachingLM"]) -> None: if cachinglm is None: - self.dbdict: Optional["SqliteDict"] = None + self.dbdict: SqliteDict | None = None return self.dbdict = cachinglm.dbdict @@ -292,15 +294,12 @@ def _fn(requests: list["Instance"]) -> list["Instance"]: eval_logger.info( f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" ) - if remaining_reqs: - # actually run the LM on the requests that do not have cached results - rem_res = getattr(self.lm, attr)(remaining_reqs) - else: - rem_res = [] + # actually run the LM on the requests that do not have cached results + rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else [] # stick the new ones back into the list and also cache any of the new ones resptr = 0 - for req, r in zip(remaining_reqs, rem_res): + for req, r in zip(remaining_reqs, rem_res, strict=True): while res[resptr] is not None: resptr += 1 @@ -340,7 +339,7 @@ def prefix_token_id(self): @abc.abstractmethod 
def tok_encode( - self, string: str, add_special_tokens: Optional[bool] = None, **kwargs + self, string: str, add_special_tokens: bool | None = None, **kwargs ) -> list[int]: """ Tokenize a string using the model's tokenizer and return a list of token IDs. @@ -351,7 +350,7 @@ def tok_encode( @abc.abstractmethod def _loglikelihood_tokens( - self, requests: list["Instance"], **kwargs + self, requests: list[tuple[tuple[str, str], list[int], list[int]]], **kwargs ) -> list[tuple[float, bool]]: pass @@ -462,7 +461,7 @@ def loglikelihood_rolling( def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: pass - def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + def chat_template(self, chat_template: bool | str = False) -> str | None: """ Set and get the appropriate chat template for the model. This method sets the tokenizer's chat_template and returns the template string for reproducibility. diff --git a/lm_eval/config/task.py b/lm_eval/config/task.py index 79c7eecc1a9..46601a5663a 100644 --- a/lm_eval/config/task.py +++ b/lm_eval/config/task.py @@ -5,6 +5,8 @@ from inspect import getsource from typing import TYPE_CHECKING, Any +from lm_eval.defaults import default_gen_kwargs + if TYPE_CHECKING: from collections.abc import Callable @@ -145,15 +147,7 @@ def __post_init__(self) -> None: else: if self.output_type == "generate_until": # ensure that we greedily generate in absence of explicit arguments otherwise - self.generation_kwargs = { - "until": ( - None - if self.fewshot_delimiter is None - else [self.fewshot_delimiter] - ), - "do_sample": False, - "temperature": 0, - } + self.generation_kwargs = default_gen_kwargs(self.fewshot_delimiter) eval_logger.warning( f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}" ) diff --git a/lm_eval/defaults.py b/lm_eval/defaults.py new file mode 100644 index 00000000000..bc64c44ec30 --- /dev/null +++ b/lm_eval/defaults.py @@ -0,0 
+1,23 @@ +from typing import Any + + +DEFAULT_MAX_LENGTH = 2048 +DEFAULT_MAX_GEN_TOKS = 256 +DEFAULT_RANDOM_SEED = 0 +DEFAULT_OTHER_SEED = 1234 + + +def default_gen_kwargs( + until: str | list[str] | None, max_gen_toks: int = DEFAULT_MAX_GEN_TOKS +) -> dict[str, Any]: + """Returns default generation kwargs for LM evaluation.""" + _gen = { + "temperature": 0.0, + "do_sample": False, + "max_gen_toks": max_gen_toks, + } + if until is not None: + _gen["until"] = [until] if isinstance(until, str) else until + else: + _gen["until"] = [] + return _gen diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 197936cc334..81fe2514453 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -16,6 +16,7 @@ import lm_eval.api.registry import lm_eval.api.task from lm_eval.caching.cache import delete_cache +from lm_eval.defaults import DEFAULT_OTHER_SEED, DEFAULT_RANDOM_SEED from lm_eval.evaluator_utils import ( consolidate_group_results, consolidate_results, @@ -75,10 +76,10 @@ def simple_evaluate( task_manager: TaskManager | None = None, verbosity=None, predict_only: bool = False, - random_seed: int = 0, - numpy_random_seed: int = 1234, - torch_random_seed: int = 1234, - fewshot_random_seed: int = 1234, + random_seed: int = DEFAULT_RANDOM_SEED, + numpy_random_seed: int = DEFAULT_OTHER_SEED, + torch_random_seed: int = DEFAULT_OTHER_SEED, + fewshot_random_seed: int = DEFAULT_OTHER_SEED, confirm_run_unsafe_code: bool = False, metadata: dict | None = None, ): diff --git a/lm_eval/models/huggingface.py b/lm_eval/models/huggingface.py index b7f223f7e3c..fd8e5c0ac10 100644 --- a/lm_eval/models/huggingface.py +++ b/lm_eval/models/huggingface.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import logging import os from datetime import timedelta @@ -35,6 +34,7 @@ configure_pad_token, handle_stop_sequences, has_bos_prefix, + normalize_gen_kwargs, postprocess_generated_text, ) from lm_eval.models.utils_hf import ( @@ -900,6 +900,7 @@ def tok_batch_encode( 
else:             add_special_tokens = {}  +        assert self.tokenizer, "Tokenizer should be initialized at this point"         encoding = self.tokenizer(             strings,             truncation=truncation, @@ -982,11 +983,11 @@ def _model_generate(         do_sample = generation_kwargs.get("do_sample")          # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies -        if generation_kwargs.get("temperature") == 0.0 and do_sample is None: +        if (temp := generation_kwargs.get("temperature")) == 0.0 and do_sample is None:             generation_kwargs["do_sample"] = do_sample = False -        if do_sample is False and generation_kwargs.get("temperature") == 0.0: -            generation_kwargs.pop("temperature") +        if do_sample is False and temp == 0.0: +            generation_kwargs.pop("temperature", None)         # build stopping criteria         stopping_criteria = stop_sequences_criteria(             self.tokenizer, stop, context.shape[1], context.shape[0]         ) @@ -1136,7 +1137,7 @@ def _loglikelihood_tokens(         requests: list[tuple[tuple[str, str], list[int], list[int]]],         disable_tqdm: bool = False,         override_bs: int | None = None, -    ) -> list[tuple[float, bool]]: +    ) -> list[tuple[float, bool]]:  # type: ignore[invalid-method-override]     # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context         res = []  @@ -1275,11 +1276,13 @@ def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):         # create encoder attn mask and batched conts, if seq2seq         call_kwargs = {} +        assert padding_len_inp, "padding_len_inp should be set by now"         if self.backend == "causal":             batched_inps = pad_and_concat(                 padding_len_inp, inps, padding_side="right"             )  # [batch, padding_len_inp]         elif self.backend == "seq2seq": +            assert padding_len_cont, "padding_len_cont should be set by now"             # TODO: left-pad encoder inps and mask? 
batched_inps = pad_and_concat( padding_len_inp, inps @@ -1425,19 +1428,13 @@ def _collate(req: tuple[str, dict]): # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. gen_kwargs = all_gen_kwargs[0] - # unpack our keyword arguments. - if isinstance(gen_kwargs, dict): - kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 - # add EOS token to stop sequences - until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) - else: - raise TypeError( - f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" - ) - if "max_gen_toks" in kwargs: - max_gen_toks = kwargs.pop("max_gen_toks") - else: - max_gen_toks = self.max_gen_toks + assert isinstance(gen_kwargs, dict), ( + f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" + ) + kwargs = normalize_gen_kwargs(gen_kwargs, self.max_gen_toks) + # add EOS token to stop sequences + until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) + max_gen_toks = kwargs.pop("max_gen_toks") # set the max length in tokens of inputs ("context_enc") if self.backend == "causal": @@ -1459,14 +1456,19 @@ def _collate(req: tuple[str, dict]): context_enc = context_enc.to(self.device) attn_masks = attn_masks.to(self.device) - if "max_length" not in kwargs: - kwargs["max_length"] = context_enc.shape[1] + max_gen_toks + # max_length = context + generation tokens + if "max_length" in kwargs: + eval_logger.warning( + "`max_length` in generation kwargs. Please use `max_gen_toks` instead." 
+ ) + max_length = kwargs.pop("max_length", context_enc.shape[1] + max_gen_toks) # type: ignore # perform batched generation cont = self._model_generate( context=context_enc, attention_mask=attn_masks, stop=until, + max_length=max_length, **kwargs, ) diff --git a/lm_eval/models/utils.py b/lm_eval/models/utils.py index 38bec519eb7..de887ddbf50 100644 --- a/lm_eval/models/utils.py +++ b/lm_eval/models/utils.py @@ -13,6 +13,10 @@ TypeVar, ) +from typing_extensions import TypedDict + +from lm_eval.utils import maybe_warn, warning_once + eval_logger = logging.getLogger(__name__) T = TypeVar("T") @@ -26,6 +30,15 @@ from transformers.configuration_utils import PretrainedConfig +class GenKwargs(TypedDict, total=False): + do_sample: bool + temperature: float + # other alias' will be converted to `max_gen_toks`. + max_gen_toks: int + until: list[str] + __extra_items__: Any + + def chunks(iter, n: int = 0, fn=None): """ Divides an iterable into chunks of specified size or based on a given function. @@ -605,13 +618,109 @@ def handle_stop_sequences(until: str | list[str] | None, eos: str | None) -> lis return until +def normalize_gen_kwargs( + gen_kwargs: dict, + default_max_gen_toks: int = 256, +) -> GenKwargs: + """Normalize generation kwargs for consistent handling across model backends. + + Model implementations may have different expectations for generation parameters. + + Args: + gen_kwargs: Raw generation kwargs from the request. Expected keys include: + - do_sample: Whether to use sampling (vs greedy decoding) - Required + - until (str | list[str]): Stop sequence(s) for generation. + - max_gen_toks | max_new_tokens | max_tokens | max_completion_tokens: Maximum tokens to generate + - temperature: Sampling temperature + - Other backend-specific kwargs + default_max_gen_toks: Default max_gen_toks if not specified in gen_kwargs. + + Returns: + A normalized dict containing: + - do_sample (bool): Whether to use sampling (bool) + - until: list[str]: List of stop sequences. 
+ - max_gen_toks (int): Maximum tokens to generate (int) + - temperature (float): Sampling temperature (float). Note: will always be set to 0.0 if do_sample=False or do_sample is not specified. + - All other kwargs passed through unchanged + + Notes: + - Accepts `max_gen_toks` and other aliases. Priority: + max_gen_toks > max_new_tokens > max_tokens > max_completion_tokens. + Output always uses `max_gen_toks`. + - When `do_sample=False`, temperature is set to 0.0 for greedy decoding. + - When temperature is 0.0 and `do_sample` is not specified, `do_sample` is set + to False. + - Model backends may further modify the returned dict as needed (e.g., vLLM + removes `do_sample` since it uses temperature directly). + """ + + import copy + + kwargs = copy.deepcopy(gen_kwargs) + + until = kwargs.get("until", []) + if not isinstance(until, list): + until = [until] + + # Extract max_gen_toks from various aliases (priority order: max_gen_toks > max_new_tokens > max_tokens > max_completion_tokens) + max_token_aliases = { + "max_gen_toks": kwargs.pop("max_gen_toks", None), + "max_new_tokens": kwargs.pop("max_new_tokens", None), # used in HF + "max_tokens": kwargs.pop( + "max_tokens", None + ), # used by vllm, OpenAI API's and others + "max_completion_tokens": kwargs.pop( + "max_completion_tokens", None + ), # newer OpenAI API alias + # note: `max_length` is also used by HF but has different semantics (prompt + generation) + } + provided = {k: v for k, v in max_token_aliases.items() if v is not None} + + if len(provided) > 1: + warning_once( + eval_logger, + f"Multiple max token args provided: {provided}. 
Using first by priority (max_gen_toks > max_new_tokens > max_tokens > max_completion_tokens).", +        ) + +    max_gen_toks = int(next(iter(provided.values()), default_max_gen_toks)) + +    # Handle do_sample and temperature consistently +    do_sample: bool | None = kwargs.get("do_sample") +    temperature: float | None = float(kwargs.get("temperature", 0.0)) + +    match do_sample: +        case None: +            kwargs["do_sample"] = True if temperature > 0.0 else False  # noqa: SIM210 +        # do_sample=False -> temperature=0.0 +        case False: +            if temperature and temperature != 0.0: +                warning_once( +                    eval_logger, +                    f"{do_sample=} but {temperature=}; setting `temperature` to 0.0 for greedy decoding. For non-greedy decoding, set `do_sample=True`.", +                ) +            kwargs["temperature"] = 0.0 +        case True: +            # do_sample=True -> use provided kwargs +            if temperature == 0.0: +                warning_once( +                    eval_logger, +                    f"{do_sample=} but {temperature=}. For non-greedy sampling, set temperature > 0.0", +                ) + +    # Set normalized values +    kwargs["until"] = until +    kwargs["max_gen_toks"] = max_gen_toks + +    return GenKwargs(**kwargs)  # type:ignore[missing-typed-dict-key] + + def resize_image(     image: Image.Image,     width: int | None = None,     height: int | None = None,     max_dimension: int | None = None,     keep_aspect_ratio: bool = True, -    resample_filter: int | str = "Image.BICUBIC", +    resample_filter: int | None = None,     min_width: int = 1,     min_height: int = 1, ) -> Image.Image: @@ -708,19 +817,94 @@ def resize_image( def truncate_tokens(     tokens: list[int],     max_length: int, -    tokenizer: PreTrainedTokenizerBase, -    strategy: str = "left", -): -    if strategy == "left": -        return tokens[-max_length:] -    elif strategy == "right": -        return tokens[:max_length] -    elif strategy == "middle": -        # Truncate the middle of the sequence -        left_length = max_length // 2 -        right_length = max_length - left_length -        return tokens[:left_length] + tokens[-right_length:] -    return None +    side: Literal["left", "middle", "right"] = "left", +) -> list[int]: +    """Truncate a 
token list to max_length using the given strategy (left, right, or middle).""" + # fmt: off + match side: + case "left": return tokens[-max_length:] + case "right": return tokens[:max_length] + case "middle": + # Truncate the middle of the sequence + left_length = max_length // 2 + right_length = max_length - left_length + return tokens[:left_length] + tokens[-right_length:] + case _: raise ValueError(f"Unknown truncation {side=}. Must be one of 'left', 'middle', or 'right'.") + # fmt: on + + +def maybe_truncate( + tokens: list[int], + max_gen_toks: int, + max_model_len: int, + min_gen_toks: int = 1, + side: Literal["left", "middle", "right"] = "left", + shrink_gen_toks=False, + verbose=True, +) -> tuple[list[int], int]: + """ + Truncates input tokens and/or reduces max_gen_toks to fit within max_model_len. + + Strategy: + 1. No truncation needed: If len(tokens) + max_gen_toks <= max_model_len, return as-is. + 2. If shrink_gen_toks=False: Truncate context to fit max_model_len - max_gen_toks. + 3. If shrink_gen_toks=True: + a. First try reducing max_gen_toks (down to min_gen_toks) to fit the context. + b. If context still doesn't fit, truncate context to reserve space for min_gen_toks. + + Args: + tokens (list[int]): The input context tokens to potentially truncate. + max_gen_toks (int): The maximum number of tokens to generate. + max_model_len (int): The model's maximum context window size (prompt + generation). + min_gen_toks (int): Lower bound for generation tokens. Defaults to 1. + side (str): "left" | "right" | "middle". Defaults to "left". + shrink_gen_toks (bool): Whether to adjust the generation tokens count + to fit within the maximum length. Defaults to False. + verbose (bool): Whether to log warnings when truncation or adjustments occur. + + Returns: + tuple[list[int], int]: A tuple containing: + - list[int]: The (possibly truncated) context tokens. + - int: The adjusted maximum generation token count. 
+ + Raises: + ValueError: when max_model_len <= min_gen_toks. + """ + ctx_len = len(tokens) + + # Case 1: Everything fits comfortably + if ctx_len + max_gen_toks <= max_model_len: + return tokens, max_gen_toks + + warning = f"Context length ({ctx_len}) + max_gen_toks ({max_gen_toks}) = {ctx_len + max_gen_toks} exceeds model's max length ({max_model_len})" + + # Case 2: Do not adjust generation tokens, just truncate prompt + if not shrink_gen_toks: + maybe_warn(f"{warning}. Truncating context from {side=}.", verbose) + return truncate_tokens( + tokens, max_model_len - max_gen_toks, side=side + ), max_gen_toks + + # Case 3: Prompt fits, but need to reduce max_tokens + if (new_max := max_model_len - ctx_len) >= min_gen_toks: + maybe_warn( + f"{warning}. Reducing {max_gen_toks=} to {new_max} to fit within model context window.", + verbose, + ) + return tokens, new_max + + # Case 4: Need to truncate prompt to fit min_tokens + # Reserve space for min_tokens, use rest for prompt + if (max_ctx_len := max_model_len - min_gen_toks) <= 0: + raise ValueError( + f"Model context window ({max_model_len}) is too small to fit " + f"initial context len ({ctx_len}) + minimum generation len ({min_gen_toks})" + ) + maybe_warn( + f"{warning}. 
Truncating context from {side=} to {max_ctx_len} tokens to reserve {min_gen_toks=} for generation.", + verbose, + ) + return truncate_tokens(tokens, max_ctx_len, side=side), min_gen_toks def postprocess_generated_text( diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py index 083e45334b6..c14dc61ff87 100644 --- a/lm_eval/models/vllm_causallms.py +++ b/lm_eval/models/vllm_causallms.py @@ -27,6 +27,8 @@ configure_pad_token, handle_stop_sequences, has_bos_prefix, + maybe_truncate, + normalize_gen_kwargs, postprocess_generated_text, undistribute, ) @@ -56,6 +58,7 @@ from transformers import PreTrainedTokenizerBase from lm_eval.api.instance import Instance + from lm_eval.models.utils import GenKwargs eval_logger = logging.getLogger(__name__) @@ -137,7 +140,7 @@ def __init__( quantization: str | None = None, max_gen_toks: int = 256, swap_space: int = 4, - batch_size: str | int = 1, + batch_size: str | int = "auto", max_batch_size=None, max_length: int | None = None, max_model_len: int | None = None, @@ -151,6 +154,7 @@ def __init__( # End marker for thinking tags - splits to get response after this token (if provided). 
think_end_token: str | None = None, max_lora_rank: int = 16, + truncation_side: Literal["left", "right", "middle"] = "left", **kwargs, ): super().__init__() @@ -169,6 +173,8 @@ def __init__( self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0" self._max_length = max_model_len if max_model_len is not None else max_length self.tensor_parallel_size = int(tensor_parallel_size) + # truncation strategy for inputs exceeding max length + self.truncation_side = truncation_side self.data_parallel_size = int(data_parallel_size) self.model_args = { "model": pretrained, @@ -195,7 +201,7 @@ def __init__( else int(batch_size) ) if self.data_parallel_size <= 1: - self.model = LLM(**self.model_args) + self.model = LLM(**self.model_args) # type: ignore[invalid-argument-type] else: eval_logger.warning( "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached." @@ -479,7 +485,7 @@ def run_inference_one_model( # We use Process as it is non-daemonic try: for rank, (req, sp) in enumerate( - zip(requests, sampling_params, strict=True) + zip(requests, sampling_params, strict=True) # type: ignore[invalid-argument-type] ): # type:ignore[invalid-argument-type] proc = Process( target=_vllm_mp_worker, @@ -660,34 +666,29 @@ def _collate_gen(_requests): context, context_encoding = zip(*context_and_encoding, strict=True) context_encoding_truncated = [] sampling_params = [] - for x, gen_kwargs in zip(context_encoding, all_gen_kwargs, strict=True): - # unpack our keyword arguments. 
- if isinstance(gen_kwargs, dict): - kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 - # add EOS token to stop sequences - until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) - else: - raise ValueError( - f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" - ) + for toks, gen_kwargs in zip(context_encoding, all_gen_kwargs, strict=True): + assert isinstance(gen_kwargs, dict), ( + f"Expected `gen_kwargs` to be of type `dict` but got {type(gen_kwargs)}" + ) - if "max_gen_toks" in kwargs: - max_gen_toks = int(kwargs.pop("max_gen_toks")) - else: - max_gen_toks = int(kwargs.pop("max_tokens", self.max_gen_toks)) + gen_kwargs = normalize_gen_kwargs( + gen_kwargs, default_max_gen_toks=self.max_gen_toks + ) + kwargs, until, max_gen_toks = self.modify_gen_kwargs( + gen_kwargs, eos=eos, default_max_gen_toks=self.max_gen_toks + ) # set the max length in tokens of inputs ("context_enc") # max len for inputs = max length, minus room to generate the max new tokens - max_ctx_len = self.max_length - max_gen_toks - if len(x) > max_ctx_len: - eval_logger.warning( - f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context." 
- ) - context_encoding_truncated.append(x[-max_ctx_len:]) - else: - context_encoding_truncated.append(x) - # create sampling params - kwargs = self.modify_gen_kwargs(kwargs) + toks, max_gen_toks = maybe_truncate( + toks, + max_gen_toks=max_gen_toks, + max_model_len=self.max_length, + side=self.truncation_side, + verbose=True, + ) + context_encoding_truncated.append(toks) + sampling_params.append( SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs) ) @@ -841,21 +842,40 @@ def coerce_logprob_to_num(logprob): return continuation_logprobs, is_greedy @staticmethod - def modify_gen_kwargs(kwargs: dict) -> dict: - # sampling_params - kwargs["temperature"] = kwargs.get("temperature", 0.0) - do_sample = kwargs.pop("do_sample", None) - if do_sample is False and "temperature" not in kwargs: - eval_logger.debug( - "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..." - ) - kwargs["temperature"] = 0.0 - # hf defaults + def modify_gen_kwargs( + gen_kwargs: GenKwargs, + eos: str | list[str] | None = None, + default_max_gen_toks: int = 256, + ) -> tuple[dict, list[str], int]: + """Process generation kwargs into vLLM-compatible format. + + Args: + gen_kwargs: Raw generation kwargs from the request. + eos: EOS token string for stop sequence handling. + default_max_gen_toks: Default max tokens if not specified in gen_kwargs. 
+ + Returns: + A tuple of (kwargs, stop_sequences, max_gen_toks) where: + - kwargs: Processed kwargs ready for SamplingParams + - stop_sequences: List of stop sequences including EOS + - max_gen_toks: Maximum tokens to generate + """ + kwargs = {**copy.deepcopy(gen_kwargs)} + + # Extract and process stop sequences + until = handle_stop_sequences( + kwargs.pop("until", None), eos=eos[0] if isinstance(eos, list) else eos + ) + + # Extract max_tokens + max_gen_toks = int(kwargs.pop("max_gen_toks", default_max_gen_toks)) + + # do_sample and temperature normalization is handled by `normalize_gen_kwargs` utility + kwargs.pop("do_sample", None) + # HF defaults kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False) kwargs["spaces_between_special_tokens"] = kwargs.get( "spaces_between_special_tokens", False ) - # remove `max_gen_toks` - if "max_gen_toks" in kwargs: - del kwargs["max_gen_toks"] - return kwargs + + return kwargs, until, max_gen_toks diff --git a/lm_eval/utils.py b/lm_eval/utils.py index 9dcdf79c22d..2cc4f075f68 100644 --- a/lm_eval/utils.py +++ b/lm_eval/utils.py @@ -119,6 +119,13 @@ def info_once(logger: logging.Logger, msg: str, *args): logger.info(msg, *args) +def maybe_warn(msg: str, verbose: bool = True): + """Log a warning message only when verbose is True, otherwise noop.""" + if verbose: + logger = logging.getLogger(__name__) + logger.warning(msg) + + def hash_string(string: str) -> str: return hashlib.sha256(string.encode("utf-8")).hexdigest() diff --git a/pyproject.toml b/pyproject.toml index ec9f87d947d..8c91f1717b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "dill", "word2number", "more_itertools", + "typing_extensions" ] [tool.setuptools.packages.find] diff --git a/tests/models/test_model_utils.py b/tests/models/test_model_utils.py new file mode 100644 index 00000000000..5543d14b0af --- /dev/null +++ b/tests/models/test_model_utils.py @@ -0,0 +1,301 @@ +import pytest + +from 
lm_eval.models.utils import maybe_truncate, normalize_gen_kwargs, truncate_tokens + + +class TestTruncateTokens: + def test_left(self): + tokens = [1, 2, 3, 4, 5] + assert truncate_tokens(tokens, 3, side="left") == [3, 4, 5] + + def test_right(self): + tokens = [1, 2, 3, 4, 5] + assert truncate_tokens(tokens, 3, side="right") == [1, 2, 3] + + def test_middle(self): + tokens = [1, 2, 3, 4, 5] + # max_length=3: left_length=1, right_length=2 -> [1] + [4, 5] + assert truncate_tokens(tokens, 3, side="middle") == [1, 4, 5] + + def test_middle_even(self): + tokens = [1, 2, 3, 4, 5, 6] + # max_length=4: left_length=2, right_length=2 -> [1, 2] + [5, 6] + assert truncate_tokens(tokens, 4, side="middle") == [1, 2, 5, 6] + + def test_no_truncation_needed(self): + tokens = [1, 2, 3] + assert truncate_tokens(tokens, 5, side="left") == [1, 2, 3] + + def test_unknown_strategy(self): + with pytest.raises(ValueError) as execinfo: + truncate_tokens([1, 2, 3], 2, side="unknown") # type: ignore + assert "Unknown truncation side" in str(execinfo.value) + + +class TestMaybeTruncate: + """Tests for maybe_truncate with different truncation strategies.""" + + # Case 1: Everything fits + def test_case1_no_truncation(self): + tokens = [1, 2, 3, 4, 5] + result_tokens, result_gen = maybe_truncate( + tokens, max_gen_toks=5, max_model_len=10 + ) + assert result_tokens == [1, 2, 3, 4, 5] + assert result_gen == 5 + + def test_case1_no_truncation_with_adjust(self): + tokens = [1, 2, 3, 4, 5] + result_tokens, result_gen = maybe_truncate( + tokens, max_gen_toks=5, max_model_len=10, shrink_gen_toks=True + ) + assert result_tokens == [1, 2, 3, 4, 5] + assert result_gen == 5 + + # Case 2: shrink_gen_toks=False — truncate prompt to max_len - max_gen_toks, keep max_gen_toks + def test_case2_truncate_prompt_no_adjust(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, max_gen_toks=5, max_model_len=6, shrink_gen_toks=False + ) + # Left-truncates prompt to 
max_len - max_gen_toks = 1, keeps max_gen_toks=5 + assert result_tokens == [10] + assert result_gen == 5 + + def test_case2_no_adjust_is_default(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, max_gen_toks=5, max_model_len=6 + ) + assert result_tokens == [10] + assert result_gen == 5 + + def test_case2_prompt_fits_but_gen_too_large_no_adjust(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8] + result_tokens, result_gen = maybe_truncate( + tokens, max_gen_toks=3, max_model_len=8, shrink_gen_toks=False + ) + # Prompt (8) + gen (3) > max_len (8), truncate prompt to 8 - 3 = 5 + assert result_tokens == [4, 5, 6, 7, 8] + assert result_gen == 3 + + # Case 3: adjust_gen_toks=True — reduce gen toks if prompt fits + def test_case3_reduce_gen_toks(self): + tokens = [1, 2, 3, 4, 5] + result_tokens, result_gen = maybe_truncate( + tokens, max_gen_toks=10, max_model_len=8, shrink_gen_toks=True + ) + assert result_tokens == [1, 2, 3, 4, 5] + assert result_gen == 3 + + # Case 4: adjust_gen_toks=True — truncate prompt with strategy + def test_case4_truncate_left(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=6, + min_gen_toks=2, + side="left", + shrink_gen_toks=True, + ) + assert result_tokens == [7, 8, 9, 10] + assert result_gen == 2 + + def test_case4_truncate_right(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=6, + min_gen_toks=2, + side="right", + shrink_gen_toks=True, + ) + assert result_tokens == [1, 2, 3, 4] + assert result_gen == 2 + + def test_case4_truncate_middle(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=6, + min_gen_toks=2, + side="middle", + shrink_gen_toks=True, + ) + # max_ctx_len=4: left=2, right=2 -> [1, 2] + [9, 10] + assert 
result_tokens == [1, 2, 9, 10] + assert result_gen == 2 + + def test_case4_default_strategy_is_left(self): + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=6, + min_gen_toks=2, + shrink_gen_toks=True, + ) + assert result_tokens == [7, 8, 9, 10] + assert result_gen == 2 + + def test_min_gen_toks_zero_reduces_to_zero(self): + # Prompt exactly fills context window, gen toks reduced to 0 + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=10, + min_gen_toks=0, + shrink_gen_toks=True, + ) + assert result_tokens == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert result_gen == 0 + + def test_min_gen_toks_zero_truncates_prompt(self): + # Prompt exceeds max_len, but min_gen_toks=0 means all space goes to prompt + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + result_tokens, result_gen = maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=8, + min_gen_toks=0, + shrink_gen_toks=True, + ) + # max_ctx_len = 8 - 0 = 8, left-truncate to 8 + assert result_tokens == [3, 4, 5, 6, 7, 8, 9, 10] + assert result_gen == 0 + + def test_raises_when_max_len_too_small(self): + tokens = [1, 2, 3, 4, 5] + with pytest.raises(ValueError): + maybe_truncate( + tokens, + max_gen_toks=5, + max_model_len=2, + min_gen_toks=3, + shrink_gen_toks=True, + ) + + +class TestNormalizeGenKwargs: + """Tests for normalize_gen_kwargs utility function.""" + + # --- until normalization --- + + def test_until_string_converted_to_list(self): + result = normalize_gen_kwargs({"until": "stop"}) + assert result["until"] == ["stop"] + + def test_until_list_passed_through(self): + result = normalize_gen_kwargs({"until": ["stop1", "stop2"]}) + assert result["until"] == ["stop1", "stop2"] + + def test_until_missing_defaults_to_empty_list(self): + result = normalize_gen_kwargs({}) + assert result["until"] == [] + + # --- max token aliases --- + + def 
test_max_gen_toks_used_directly(self): + result = normalize_gen_kwargs({"max_gen_toks": 100}) + assert result["max_gen_toks"] == 100 + + def test_max_new_tokens_converted(self): + result = normalize_gen_kwargs({"max_new_tokens": 150}) + assert result["max_gen_toks"] == 150 + + def test_max_tokens_converted(self): + result = normalize_gen_kwargs({"max_tokens": 200}) + assert result["max_gen_toks"] == 200 + + def test_max_completion_tokens_converted(self): + result = normalize_gen_kwargs({"max_completion_tokens": 250}) + assert result["max_gen_toks"] == 250 + + def test_default_max_gen_toks_when_none_provided(self): + result = normalize_gen_kwargs({}) + assert result["max_gen_toks"] == 256 + + def test_custom_default_max_gen_toks(self): + result = normalize_gen_kwargs({}, default_max_gen_toks=512) + assert result["max_gen_toks"] == 512 + + def test_max_token_priority_max_gen_toks_first(self): + result = normalize_gen_kwargs( + { + "max_gen_toks": 100, + "max_new_tokens": 200, + "max_tokens": 300, + } + ) + assert result["max_gen_toks"] == 100 + + def test_max_token_priority_max_new_tokens_second(self): + result = normalize_gen_kwargs( + { + "max_new_tokens": 200, + "max_tokens": 300, + "max_completion_tokens": 400, + } + ) + assert result["max_gen_toks"] == 200 + + def test_max_token_priority_max_tokens_third(self): + result = normalize_gen_kwargs( + { + "max_tokens": 300, + "max_completion_tokens": 400, + } + ) + assert result["max_gen_toks"] == 300 + + # --- do_sample and temperature interaction --- + + def test_do_sample_none_temperature_zero_sets_do_sample_false(self): + result = normalize_gen_kwargs({"temperature": 0.0}) + assert result["do_sample"] is False + + def test_do_sample_none_temperature_positive_sets_do_sample_true(self): + result = normalize_gen_kwargs({"temperature": 0.7}) + assert result["do_sample"] is True + + def test_do_sample_false_sets_temperature_zero(self): + result = normalize_gen_kwargs({"do_sample": False}) + assert result["temperature"] 
== 0.0 + + def test_do_sample_false_temperature_positive_forces_temperature_zero(self): + result = normalize_gen_kwargs({"do_sample": False, "temperature": 0.8}) + assert result["temperature"] == 0.0 + + def test_do_sample_true_temperature_positive_preserved(self): + result = normalize_gen_kwargs({"do_sample": True, "temperature": 0.9}) + assert result["do_sample"] is True + assert result["temperature"] == 0.9 + + def test_do_sample_true_temperature_zero_preserved(self): + result = normalize_gen_kwargs({"do_sample": True, "temperature": 0.0}) + assert result["do_sample"] is True + assert result["temperature"] == 0.0 + + # --- other behaviors --- + + def test_extra_kwargs_passed_through(self): + result = normalize_gen_kwargs( + { + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.1, + } + ) + assert result["top_p"] == 0.95 # type: ignore + assert result["top_k"] == 50 # type: ignore + assert result["repetition_penalty"] == 1.1 # type: ignore + + def test_original_dict_not_mutated(self): + original = {"until": "stop", "max_gen_toks": 100, "temperature": 0.5} + original_copy = original.copy() + normalize_gen_kwargs(original) + assert original == original_copy