Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lm_eval/_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def key_val_to_dict(args: str) -> dict[str, Any]:
if not args:
return res

for k, v in (item.split("=",1) for item in args.split(",")):
for k, v in (item.split("=", 1) for item in args.split(",")):
v = handle_cli_value_string(v)
if k in res:
eval_logger.warning(f"Overwriting key '{k}': {res[k]!r} -> {v!r}")
Expand Down
31 changes: 15 additions & 16 deletions lm_eval/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import json
import logging
import os
from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar

from tqdm import tqdm

Expand Down Expand Up @@ -31,7 +32,7 @@ def __init__(self) -> None:
# set rank and world size to a single process, by default.
self._rank = 0
self._world_size = 1
self.cache_hook: "CacheHook" = CacheHook(None)
self.cache_hook: CacheHook = CacheHook(None)

@abc.abstractmethod
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
Expand Down Expand Up @@ -68,7 +69,8 @@ def loglikelihood_rolling(self, requests) -> list[float]:
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still be a full-sized context.
Example:

Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: BOS/EOS
Max context length: 4
Expand Down Expand Up @@ -137,7 +139,7 @@ def apply_chat_template(

@classmethod
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
cls: type[T], arg_string: str, additional_config: dict | None = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Expand All @@ -156,7 +158,7 @@ def create_from_arg_string(

@classmethod
def create_from_arg_obj(
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
cls: type[T], arg_dict: dict, additional_config: dict | None = None
) -> T:
"""
Creates an instance of the LM class using the given arg_obj
Expand Down Expand Up @@ -201,7 +203,7 @@ def tokenizer_name(self) -> str:
"To use this model with chat templates, please implement the 'tokenizer_name' property."
)

def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
"""Returns the chat template structure for user/assistant messages if a template is provided.
This method is intended to be overridden in a subclass to define a specific chat template format.
For models that do not support chat templates, this method returns None by default.
Expand All @@ -222,7 +224,7 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
class CacheHook:
def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
if cachinglm is None:
self.dbdict: Optional["SqliteDict"] = None
self.dbdict: SqliteDict | None = None
return

self.dbdict = cachinglm.dbdict
Expand Down Expand Up @@ -292,15 +294,12 @@ def _fn(requests: list["Instance"]) -> list["Instance"]:
eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
)
if remaining_reqs:
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []

# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
for req, r in zip(remaining_reqs, rem_res, strict=True):
while res[resptr] is not None:
resptr += 1

Expand Down Expand Up @@ -340,7 +339,7 @@ def prefix_token_id(self):

@abc.abstractmethod
def tok_encode(
self, string: str, add_special_tokens: Optional[bool] = None, **kwargs
self, string: str, add_special_tokens: bool | None = None, **kwargs
) -> list[int]:
"""
Tokenize a string using the model's tokenizer and return a list of token IDs.
Expand All @@ -351,7 +350,7 @@ def tok_encode(

@abc.abstractmethod
def _loglikelihood_tokens(
self, requests: list["Instance"], **kwargs
self, requests: list[tuple[tuple[str, str], list[int], list[int]]], **kwargs
) -> list[tuple[float, bool]]:
pass

Expand Down Expand Up @@ -462,7 +461,7 @@ def loglikelihood_rolling(
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
pass

def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
"""
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
Expand Down
12 changes: 3 additions & 9 deletions lm_eval/config/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from inspect import getsource
from typing import TYPE_CHECKING, Any

from lm_eval.defaults import default_gen_kwargs


if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -145,15 +147,7 @@ def __post_init__(self) -> None:
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
self.generation_kwargs = default_gen_kwargs(self.fewshot_delimiter)
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
Expand Down
23 changes: 23 additions & 0 deletions lm_eval/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any


DEFAULT_MAX_LENGTH = 2048
DEFAULT_MAX_GEN_TOKS = 256
DEFAULT_RANDOM_SEED = 0
DEFAULT_OTHER_SEED = 1234


def default_gen_kwargs(
until: str | list[str] | None, max_gen_toks: int = DEFAULT_MAX_GEN_TOKS
) -> dict[str, Any]:
"""Returns default generation kwargs for LM evaluation."""
_gen = {
"temperature": 0.0,
"do_sample": False,
"max_gen_toks": max_gen_toks,
}
if until is not None:
_gen["until"] = [until] if isinstance(until, str) else until
else:
_gen["until"] = []
return _gen
9 changes: 5 additions & 4 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import lm_eval.api.registry
import lm_eval.api.task
from lm_eval.caching.cache import delete_cache
from lm_eval.defaults import DEFAULT_OTHER_SEED, DEFAULT_RANDOM_SEED
from lm_eval.evaluator_utils import (
consolidate_group_results,
consolidate_results,
Expand Down Expand Up @@ -75,10 +76,10 @@ def simple_evaluate(
task_manager: TaskManager | None = None,
verbosity=None,
predict_only: bool = False,
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
random_seed: int = DEFAULT_RANDOM_SEED,
numpy_random_seed: int = DEFAULT_OTHER_SEED,
torch_random_seed: int = DEFAULT_OTHER_SEED,
fewshot_random_seed: int = DEFAULT_OTHER_SEED,
confirm_run_unsafe_code: bool = False,
metadata: dict | None = None,
):
Expand Down
42 changes: 22 additions & 20 deletions lm_eval/models/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import copy
import logging
import os
from datetime import timedelta
Expand Down Expand Up @@ -35,6 +34,7 @@
configure_pad_token,
handle_stop_sequences,
has_bos_prefix,
normalize_gen_kwargs,
postprocess_generated_text,
)
from lm_eval.models.utils_hf import (
Expand Down Expand Up @@ -900,6 +900,7 @@ def tok_batch_encode(
else:
add_special_tokens = {}

assert self.tokenizer, "Tokenizer should be initialized at this point"
encoding = self.tokenizer(
strings,
truncation=truncation,
Expand Down Expand Up @@ -982,11 +983,11 @@ def _model_generate(
do_sample = generation_kwargs.get("do_sample")

# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
if (temp := generation_kwargs.get("temperature")) == 0.0 and do_sample is None:
generation_kwargs["do_sample"] = do_sample = False

if do_sample is False and generation_kwargs.get("temperature") == 0.0:
generation_kwargs.pop("temperature")
if do_sample is False and temp == 0.0:
generation_kwargs.pop("temperature", None)
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, context.shape[1], context.shape[0]
Expand Down Expand Up @@ -1136,7 +1137,7 @@ def _loglikelihood_tokens(
requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False,
override_bs: int | None = None,
) -> list[tuple[float, bool]]:
) -> list[tuple[float, bool]]: # type: ignore[invalid-method-override]
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []

Expand Down Expand Up @@ -1275,11 +1276,13 @@ def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
assert padding_len_inp, "padding_len_inp should be set by now"
if self.backend == "causal":
batched_inps = pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.backend == "seq2seq":
assert padding_len_cont, "padding_len_cont should be set by now"
# TODO: left-pad encoder inps and mask?
batched_inps = pad_and_concat(
padding_len_inp, inps
Expand Down Expand Up @@ -1425,19 +1428,13 @@ def _collate(req: tuple[str, dict]):
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise TypeError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if "max_gen_toks" in kwargs:
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
assert isinstance(gen_kwargs, dict), (
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
kwargs = normalize_gen_kwargs(gen_kwargs, self.max_gen_toks)
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
max_gen_toks = kwargs.pop("max_gen_toks")

# set the max length in tokens of inputs ("context_enc")
if self.backend == "causal":
Expand All @@ -1459,14 +1456,19 @@ def _collate(req: tuple[str, dict]):
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)

if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
# max_length = context + generation tokens
if "max_length" in kwargs:
eval_logger.warning(
"`max_length` in generation kwargs. Please use `max_gen_toks` instead."
)
max_length = kwargs.pop("max_length", context_enc.shape[1] + max_gen_toks) # type: ignore

# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
stop=until,
max_length=max_length,
**kwargs,
)

Expand Down
Loading