From 06ead43db65ba7ab360d1927ac451829751b1f1b Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Wed, 21 Aug 2024 13:21:20 +0000 Subject: [PATCH 01/11] add vlmm backend --- src/lighteval/data.py | 14 ++ src/lighteval/models/model_config.py | 23 ++ src/lighteval/models/model_loader.py | 9 + src/lighteval/models/vllm_model.py | 362 +++++++++++++++++++++++++++ 4 files changed, 408 insertions(+) create mode 100644 src/lighteval/models/vllm_model.py diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 22b68bd6a..9268ec9ff 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -161,6 +161,20 @@ def __len__(self) -> int: """ return self.split_end - self.split_start + def __iter__(self) -> Iterator[Request]: + """ + Iterator that yields the items of the dataset depending on the split we + are currently in. For instance, if we are in split 0, we will get the + items from index 0 to self.split_size, if we are in split 1, we will get + the items from index self.split_size to 2 * self.split_size, etc. Used + for dynamic batching. + + Yields: + Any: The items of the dataset. + """ + for i in range(self.split_start, self.split_end): + yield self.sorted_data[i] + def _sorting_criteria(self, request) -> int: raise NotImplementedError() diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index d1e5ab38d..a0e797f6b 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -219,6 +219,25 @@ def init_configs(self, env_config: EnvConfig): return self._init_configs(self.base_model, env_config) +@dataclass +class VLLMModelConfig: + pretrained: str + gpu_memor_utilisation: float = 0.9 + batch_size: int = 1 + revision: str = "main" + dtype: str = "float16" + tensor_parallel_size: int = 1 + data_parallel_size: int = 1 + max_model_length: int = 1024 + swap_space: int = 4 + seed: int = 1234 + trust_remote_code: bool = False + use_chat_template: bool = False + add_special_tokens: bool = True + multichoice_continuations_start_space: bool = True + subfolder: Optional[str] = None + + @dataclass class TGIModelConfig: inference_server_address: str @@ -290,6 +309,7 @@ def create_model_config( # noqa: C901 TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig, + VLLMModelConfig, ]: """ Create a model configuration based on the provided arguments. 
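The new VLLMModelConfig above is a plain dataclass, so it can also be constructed directly for testing; a hypothetical instantiation (the model id is only an example, every keyword is one of the fields declared above):

    from lighteval.models.model_config import VLLMModelConfig

    config = VLLMModelConfig(
        pretrained="HuggingFaceH4/zephyr-7b-beta",  # example model id, not prescribed by this patch
        dtype="float16",
        tensor_parallel_size=1,
        data_parallel_size=1,
        max_model_length=2048,
        trust_remote_code=False,
    )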
@@ -313,6 +333,9 @@ def create_model_config( # noqa: C901 if args_dict.pop("dummy", False): return DummyModelConfig(**args_dict) + if args_dict.pop("vllm", False): + return VLLMModelConfig(**args_dict) + args_dict["accelerator"] = accelerator args_dict["use_chat_template"] = args.use_chat_template args_dict["compile"] = bool(args_dict["compile"]) if "compile" in args_dict else False diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index c72d64038..417482a48 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -38,8 +38,10 @@ InferenceEndpointModelConfig, InferenceModelConfig, TGIModelConfig, + VLLMModelConfig, ) from lighteval.models.tgi_model import ModelClient +from lighteval.models.vllm_model import VLLMModel from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available @@ -63,6 +65,7 @@ def load_model( # noqa: C901 TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig, + VLLMModelConfig, ], env_config: EnvConfig, ) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel], ModelInfo]: @@ -94,6 +97,9 @@ def load_model( # noqa: C901 if isinstance(config, DummyModelConfig): return load_dummy_model(config=config, env_config=env_config) + if isinstance(config, VLLMModelConfig): + return load_model_with_accelerate_or_default(config=config, env_config=env_config) + def load_model_with_tgi(config: TGIModelConfig): if not is_tgi_available(): @@ -135,6 +141,9 @@ def load_model_with_accelerate_or_default( model = AdapterModel(config=config, env_config=env_config) elif isinstance(config, DeltaModelConfig): model = DeltaModel(config=config, env_config=env_config) + elif isinstance(config, VLLMModelConfig): + model = VLLMModel(config=config, env_config=env_config) + return model, ModelInfo(model_name="vllm", model_sha=str(config.seed)) else: model = BaseModel(config=config, env_config=env_config) diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py new file mode 100644 index 000000000..59a3a97c4 --- /dev/null +++ b/src/lighteval/models/vllm_model.py @@ -0,0 +1,362 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
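In create_model_config above, args_dict.pop("vllm", False) both detects the backend flag and strips it, so the remaining keys can be expanded straight into the dataclass; load_model then dispatches on the config type alone. A minimal sketch of that flow (toy argument dict, simplified to the vllm branch):

    from lighteval.models.model_config import VLLMModelConfig

    args_dict = {"vllm": True, "pretrained": "gpt2", "dtype": "float16"}  # toy parsed model args

    if args_dict.pop("vllm", False):
        # "vllm" has been removed, so every remaining key must be a VLLMModelConfig field.
        config = VLLMModelConfig(**args_dict)

    # The loader later picks the backend from the config type, mirroring load_model above.
    assert isinstance(config, VLLMModelConfig)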
+ +import itertools +import os +from typing import Optional + +import ray +from more_itertools import distribute +from tqdm import tqdm +from vllm import LLM, SamplingParams +from vllm.transformers_utils.tokenizer import get_tokenizer + +from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset +from lighteval.logging.hierarchical_logger import hlog_warn +from lighteval.models.abstract_model import LightevalModel +from lighteval.models.model_config import EnvConfig, VLLMModelConfig +from lighteval.models.model_output import ( + GenerateReturn, + LoglikelihoodReturn, +) +from lighteval.models.utils import _simplify_name +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, +) +from lighteval.utils import as_list, is_accelerate_available + + +if is_accelerate_available(): + pass + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +STARTING_BATCH_SIZE = 512 + + +class VLLMModel(LightevalModel): + def __init__( + self, + config: VLLMModelConfig, + env_config: EnvConfig, + ): + """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.""" + self._config = config + self._batch_size = config.batch_size + self._max_length = self._init_max_length(config.max_model_length) + self.use_chat_template = config.use_chat_template + self.data_parallel_size = int(config.data_parallel_size) + + self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False + self._tokenizer = self._create_auto_tokenizer(config, env_config) + + # If model_parallel is not set we compare the number of processes with the number of GPUs + self.model = self._create_auto_model(config, env_config) + print(f"model: {self.model}") + + # self._device = config.accelerator.device if config.accelerator is not None else "cpu" + self.multichoice_continuations_start_space = config.multichoice_continuations_start_space + + self.model_name = _simplify_name(config.pretrained) + self.model_sha = "" # config.get_model_sha() + self.precision = "float16" # _get_dtype(config.dtype, config=self._config) + + @property + def tokenizer(self): + return self._tokenizer + + @property + def add_special_tokens(self): + return self._add_special_tokens + + @property + def max_length(self) -> int: + return self._max_length + + def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) -> Optional[LLM]: + """ + Creates an instance of the pretrained HF model. + + Args: + pretrained (str): The name or path of the pretrained model. + revision (str): The revision of the model. + subfolder (Optional[str], optional): The subfolder within the model. Defaults to None. + max_memory (Optional[dict], optional): The maximum memory to allocate for the model per GPU. Defaults to None. + device_map (Optional[dict], optional): The device mapping for the model. Defaults to None. + torch_dtype (Optional[Union[str, torch.dtype]], optional): The torch data type for the model. Defaults to None. + quantization_config (Optional[Union[BitsAndBytesConfig, GPTQConfig]], optional): The quantization configuration for the model. Defaults to None. + trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False. + cache_dir (str, optional): The cache directory for the model. Defaults to "/scratch". + + Returns: + transformers.PreTrainedModel: The created auto model instance. 
+ """ + self.model_args = { + "model": config.pretrained, + "gpu_memory_utilization": float(0.8), + "revision": config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""), + "dtype": config.dtype, + "trust_remote_code": config.trust_remote_code, + "tensor_parallel_size": int(1), + "max_model_len": int(self._max_length) if self._max_length else None, + "swap_space": 4, + "seed": 1234, + } + if int(config.data_parallel_size) > 1: + self.model_args["worker_use_ray"] = True + self._batch_size = "auto" + return None + + model = LLM(**self.model_args) + return model + + def _create_auto_tokenizer(self, config: VLLMModelConfig, env_config: EnvConfig): + tokenizer = get_tokenizer( + config.pretrained, + tokenizer_mode="auto", + trust_remote_code=config.trust_remote_code, + tokenizer_revision=config.revision, + ) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def _init_max_length(self, max_length) -> int: + """Return the maximum sequence length of the model. + NOTE: Different model configurations have different max sequence length + attribute names. + - n_positions: (CTRLConfig) + - max_position_embeddings: (BartConfig, RoFormerConfig) + - n_ctx: (GPT2Config) + NOTE: For relative position encoded models you should specify the max + sequence length of the model in the constructor via `max_length`. + + Args: + max_length (Optional[int]): The maximum length of the input sequence. If not provided, it will be determined + based on the model's configuration or tokenizer's model_max_length attribute. + + Returns: + int: Max length to use depending on the available args and config + """ + if max_length is not None: + return int(max_length) + # Try to get the sequence length from the model config. + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + + for attr in seqlen_config_attrs: + if hasattr(self._config, attr): + return getattr(self._config, attr) + + # Default max sequence length setting for when no `max_length` is provided + # or no max length config setting is found in the model or tokenizer. + return 2048 + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + override_bs: Optional[int] = None, + ) -> list[GenerateReturn]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerateReturn]: list of generated responses. + """ + for request in requests: + request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] + request.tokenized_context = self.tok_encode(request.context) + + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for _ in tqdm( + dataset.splits_start_end_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, # self.disable_tqdm, + ): + # For chat models, generation stops with EOS token, so we don't need to specify stop tokens + if self.use_chat_template: + stop_tokens = [] + else: + # NOTE: we are assuming all items in a batch behave similarly (same + # stop_tokens and max_tokens genrated) which is not necessarily + # the case! 
Because of that we only use batch size of 1 + stop_tokens = dataset[0].stop_sequence + + max_new_tokens = dataset[0].generation_size # could be none + returns_logits = dataset[0].use_logits + num_samples = dataset[0].num_samples + + context = [c.context for c in dataset] + tokenized = self.tokenizer(context, add_special_tokens=self.add_special_tokens) + + # The main question for this step is the following: + # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk + # of losing some meaning, or have some generations that are exceedingly short? + # The choice we go for here is to avoid truncating the prompt if we can, since it + # should have been managed by the prompt creator/few shot manager if requested by the user. + context_size = len(tokenized["input_ids"][0]) + if context_size > self.max_length: + hlog_warn( + f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in" + + str({dataset[0].task_name}) + + ". This is likely to lead to some errors." # noqa C401 + ) + # There will be truncation of at least one sample, maximum generation size will be one + max_new_tokens = 1 + else: # We can't allow generation of more than max_length + if max_new_tokens is None: # If generation size is not set, we go all the way + max_new_tokens = self.max_length - context_size + else: + max_new_tokens = min(self.max_length - context_size, max_new_tokens) + + vllm_outputs = self._generate( + inputs=tokenized["input_ids"], + max_new_tokens=max_new_tokens, + stop_tokens=stop_tokens, + returns_logits=returns_logits, + num_samples=num_samples, + ) + + for vllm_output in vllm_outputs: + output_token_ids = vllm_output.outputs[0].token_ids + logprobs = vllm_output.outputs[0].logprobs or [] + logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids, logprobs)] + result = vllm_output.outputs[0].text + input_token_ids = vllm_output.prompt_token_ids + + cur_response = GenerateReturn( + result=result, + logits=logprobs, + generated_tokens=list(output_token_ids), + input_tokens=input_token_ids, + ) + results.append(cur_response) + + return dataset.get_original_order(results) + + def _generate( + self, + inputs: list[list[int]], + max_new_tokens: Optional[int] = None, + stop_tokens: Optional[list[str]] = None, + returns_logits: Optional[bool] = False, + num_samples: int = 1, + generate: bool = True, + ) -> list[GenerateReturn]: + """Contains the actual logic of the generation.""" + if generate: + sampling_params = SamplingParams( + n=num_samples, max_tokens=max_new_tokens, stop=stop_tokens, logprobs=1 if returns_logits else 0 + ) + else: + sampling_params = SamplingParams(temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False) + + if self.data_parallel_size > 1: + # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote + # also seems to only work with decorator and not with ray.remote() fn + # see https://github.com/vllm-project/vllm/issues/973 + # note: this has changed on 0.3.3, and it only works now if num_gpus are set. 
+ # but then tensor_parallel breaks + @ray.remote + def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests): + llm = LLM(**model_args) + return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params) + + # dispatch requests to all self.data_parallel_size workers, in interleaved fashion + # interleaved important to balance context lengths across workers + requests = [list(x) for x in distribute(self.data_parallel_size, inputs)] + inputs = ((self.model_args, sampling_params, req) for req in requests) + object_refs = [run_inference_one_model.remote(*x) for x in inputs] + results = ray.get(object_refs) + # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required. + ray.shutdown() + # flatten results + outputs = [ + x + for x in itertools.chain.from_iterable(itertools.zip_longest(*[list(x) for x in results])) + if x is not None + ] + else: + outputs = self.model.generate( + prompt_token_ids=inputs, + sampling_params=sampling_params, + use_tqdm=True, + ) + + return outputs + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodReturn]: + for request in requests: + if request.context == "": + request.tokenized_context = [self.tokenizer.eos_token_id] + request.tokenized_continuation = self.tok_encode(request.choice) + else: + # The following line is mandatory for compatibility with the harness + request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair( + request.context, request.choice + ) + return self._loglikelihood_tokens(requests, override_bs=override_bs) + + def _loglikelihood_tokens( + self, + requests: list[LoglikelihoodRequest], + override_bs: int = -1, + return_bool_score: bool = True, + rolling: bool = False, + ) -> list[LoglikelihoodReturn]: + dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=1) + res = [] + + for _ in tqdm(dataset.splits_start_end_iterator()): + # the last token is an eos token, so we don't need to add it + inputs = [ + dataset[i].tokenized_context + dataset[i].tokenized_continuation[:-1] for i in range(len(dataset)) + ] + outputs = self._generate(inputs, generate=False) + + for output, input in zip(outputs, dataset): + continuation_logprobs = [] + for token, logprobs in zip(input.tokenized_continuation[-2::-1], output.prompt_logprobs[::-1]): + continuation_logprobs.append(logprobs[token]) + bool_score = all(logprob.rank == 1 for logprob in continuation_logprobs) + continuation_logprobs = [logprob.logprob for logprob in continuation_logprobs] + answer = LoglikelihoodReturn( + result=(sum(continuation_logprobs), bool_score if return_bool_score else None) + ) + res.append(answer) + + return dataset.get_original_order(res) + + def loglikelihood_rolling(): + pass + + def loglikelihood_single_token(): + pass From ea0cd3d2a77f44f9537d8b2022b51816446a76bf Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 26 Aug 2024 12:53:26 +0000 Subject: [PATCH 02/11] add vlmm backend --- src/lighteval/logging/hierarchical_logger.py | 6 +++--- src/lighteval/main_accelerate.py | 1 + src/lighteval/models/base_model.py | 5 +++-- src/lighteval/tasks/lighteval_task.py | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/lighteval/logging/hierarchical_logger.py b/src/lighteval/logging/hierarchical_logger.py index 9cf718f18..c5ce28778 100644 --- a/src/lighteval/logging/hierarchical_logger.py +++ b/src/lighteval/logging/hierarchical_logger.py @@ -26,6 +26,8 @@ from logging import Logger from 
typing import Any, Callable +from colorama import Fore, Style + from lighteval.utils import is_accelerate_available, is_nanotron_available @@ -37,10 +39,8 @@ from accelerate.logging import get_logger logger = get_logger(__name__, log_level="INFO") -else: - logger = Logger(__name__, level="INFO") -from colorama import Fore, Style +logger = Logger(__name__, level="INFO") class HierarchicalLogger: diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 904a68322..857756c22 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -67,6 +67,7 @@ def main(args): evaluation_tracker.general_config_logger.log_args_info( args.num_fewshot_seeds, args.override_batch_size, args.max_samples, args.job_id ) + print("HELLO") if args.max_samples: hlog( diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 666c01319..2c40c5a7b 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -92,6 +92,7 @@ def __init__( hlog(f"Using Data Parallelism, putting model on device {self._device}") self.model = self.model.to(self._device) if config.compile: + hlog("Compiling the model") self.model.model.compile() self.model_name = _simplify_name(config.pretrained) @@ -536,9 +537,9 @@ def greedy_until( tokenized = self.tokenizer( context, truncation="longest_first", # we truncate to the model max length if needed - padding="longest", # we pad to the longest sequence + padding="max_length", # we pad to the longest sequence return_tensors="pt", - max_length=self.max_length - 1, # we always allow minimum one token of generation + max_length=max_context_continuation_size_allowed, # we always allow minimum one token of generation add_special_tokens=self.add_special_tokens, ).to(self.device) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 07120b711..980f42fba 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -198,8 +198,8 @@ def __init__( # noqa: C901 for metric_name in metric_names: # If we do maj_at_ metrics, we need to use the correct number of samples - if "maj_at_" in metric_name: - self.num_samples.append(int(metric_name.replace("maj_at_", "").split("_")[0])) + if "maj@" in metric_name: + self.num_samples.append(int(metric_name.replace("maj@", "").split("_")[0])) if not isinstance(cfg.prompt_function, Callable): raise TypeError( From 0a66f408260f04de7a6e11c23cb19f0b632c4d08 Mon Sep 17 00:00:00 2001 From: Nathan Habib <30601243+NathanHB@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:06:38 +0200 Subject: [PATCH 03/11] Update src/lighteval/models/model_config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> --- src/lighteval/models/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index a0e797f6b..8bf23d3d5 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -222,7 +222,7 @@ def init_configs(self, env_config: EnvConfig): @dataclass class VLLMModelConfig: pretrained: str - gpu_memor_utilisation: float = 0.9 + gpu_memory_utilisation: float = 0.9 batch_size: int = 1 revision: str = "main" dtype: str = "float16" From 8b2392aa1c2644bcdbf5086dd37c2957a0f889b5 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 
13:28:15 +0000 Subject: [PATCH 04/11] fix from review --- pyproject.toml | 1 + src/lighteval/models/model_config.py | 8 ++++---- src/lighteval/models/vllm_model.py | 4 ++-- src/lighteval/utils.py | 7 +++++++ 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e301d7afd..c7568b2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ nanotron = [ "tensorboardX" ] tensorboardX = ["tensorboardX"] +vllm = ["vllm", "ray"] quality = ["ruff==v0.2.2","pre-commit"] tests = ["pytest==7.4.0"] dev = ["lighteval[accelerate,quality,tests]"] diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 8bf23d3d5..20512b2bc 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -222,14 +222,14 @@ def init_configs(self, env_config: EnvConfig): @dataclass class VLLMModelConfig: pretrained: str - gpu_memory_utilisation: float = 0.9 - batch_size: int = 1 + gpu_memory_utilisation: float = 0.8 + batch_size: int = -1 revision: str = "main" - dtype: str = "float16" + dtype: str | None = None tensor_parallel_size: int = 1 data_parallel_size: int = 1 max_model_length: int = 1024 - swap_space: int = 4 + swap_space: int = 4 # CPU swap space size (GiB) per GPU. seed: int = 1234 trust_remote_code: bool = False use_chat_template: bool = False diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py index 59a3a97c4..9d0d1ab90 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm_model.py @@ -43,10 +43,10 @@ GreedyUntilRequest, LoglikelihoodRequest, ) -from lighteval.utils import as_list, is_accelerate_available +from lighteval.utils import as_list, is_vllm_available -if is_accelerate_available(): +if is_vllm_available(): pass os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/src/lighteval/utils.py b/src/lighteval/utils.py index 5602a17fc..0c1c7b732 100644 --- a/src/lighteval/utils.py +++ b/src/lighteval/utils.py @@ -225,6 +225,13 @@ def is_openai_available() -> bool: NO_OPENAI_ERROR_MSG = "You are trying to use an Open AI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip." +def is_vllm_available() -> bool: + return importlib.util.find_spec("vllm") is not None and importlib.util.find_spec("ray") is not None + + +NO_VLLM_ERROR_MSG = "You are trying to use an VLLM model, for which you need `vllm` and `ray`, which are not available in your environment. Please install them using pip, `pip install vllm ray`." 
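importlib.util.find_spec only resolves the packages on sys.path without importing them, so is_vllm_available above is cheap enough to call at model-load time. A generalized sketch of the same guard (the helper name is illustrative, not part of the patch):

    import importlib.util

    def optional_backend_available(*packages: str) -> bool:
        # True only if every optional package can be resolved, without importing any of them.
        return all(importlib.util.find_spec(p) is not None for p in packages)

    if not optional_backend_available("vllm", "ray"):
        raise ImportError("vLLM backend requested but vllm/ray are missing; install them using pip, `pip install vllm ray`.")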
+ + def can_load_extended_tasks() -> bool: imports = [] for package in ["langdetect", "openai"]: From 9029943c7edb13ff4c17ad7240397e675205cb86 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 13:31:41 +0000 Subject: [PATCH 05/11] fix --- src/lighteval/logging/hierarchical_logger.py | 8 +- src/lighteval/main_accelerate.py | 136 ++++++------------- 2 files changed, 43 insertions(+), 101 deletions(-) diff --git a/src/lighteval/logging/hierarchical_logger.py b/src/lighteval/logging/hierarchical_logger.py index c5ce28778..99287f750 100644 --- a/src/lighteval/logging/hierarchical_logger.py +++ b/src/lighteval/logging/hierarchical_logger.py @@ -26,9 +26,7 @@ from logging import Logger from typing import Any, Callable -from colorama import Fore, Style - -from lighteval.utils import is_accelerate_available, is_nanotron_available +from lighteval.utils.imports import is_accelerate_available, is_nanotron_available if is_nanotron_available(): @@ -39,8 +37,10 @@ from accelerate.logging import get_logger logger = get_logger(__name__, log_level="INFO") +else: + logger = Logger(__name__, level="INFO") -logger = Logger(__name__, level="INFO") +from colorama import Fore, Style class HierarchicalLogger: diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 857756c22..22054c971 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -21,22 +21,13 @@ # SOFTWARE. import os -import random -import shutil -from contextlib import nullcontext from datetime import timedelta -import numpy as np - -from lighteval.evaluator import evaluate, make_results_table from lighteval.logging.evaluation_tracker import EvaluationTracker -from lighteval.logging.hierarchical_logger import hlog, hlog_warn, htrack, htrack_block -from lighteval.models.model_config import EnvConfig, create_model_config -from lighteval.models.model_loader import load_model -from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks -from lighteval.tasks.registry import Registry, taskinfo_selector -from lighteval.utils import is_accelerate_available, is_tgi_available -from lighteval.utils_parallelism import test_all_gather +from lighteval.logging.hierarchical_logger import hlog_warn, htrack +from lighteval.models.model_config import create_model_config +from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters +from lighteval.utils.imports import is_accelerate_available, is_tgi_available if not is_accelerate_available() and not is_tgi_available(): @@ -64,88 +55,39 @@ def main(args): public=args.public_run, token=TOKEN, ) - evaluation_tracker.general_config_logger.log_args_info( - args.num_fewshot_seeds, args.override_batch_size, args.max_samples, args.job_id + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.ACCELERATE, + env_config=env_config, + job_id=args.job_id, + dataset_loading_processes=args.dataset_loading_processes, + custom_tasks_directory=args.custom_tasks, + override_batch_size=args.override_batch_size, + num_fewshot_seeds=args.num_fewshot_seeds, + max_samples=args.max_samples, + use_chat_template=args.use_chat_template, + system_prompt=args.system_prompt, + ) + + model_config = create_model_config( + use_chat_template=args.use_chat_template, + override_batch_size=args.override_batch_size, + model_args=args.model_args, + accelerator=accelerator, + ) + + pipeline = Pipeline( + tasks=args.tasks, + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + 
model_config=model_config, ) - print("HELLO") - - if args.max_samples: - hlog( - "WARNING: --max_samples WAS SET. THESE NUMBERS ARE ONLY PARTIAL AND SHOULD NOT BE USED FOR COMPARISON UNLESS YOU KNOW WHAT YOU ARE DOING." - ) - - with htrack_block("Test all gather"): - test_all_gather(accelerator) - - with htrack_block("Creating model configuration"): - model_config = create_model_config(args=args, accelerator=accelerator) - - with htrack_block("Model loading"): - with accelerator.main_process_first() if accelerator is not None else nullcontext(): - model, model_info = load_model(config=model_config, env_config=env_config) - evaluation_tracker.general_config_logger.log_model_info(model_info) - - with htrack_block("Tasks loading"): - with accelerator.main_process_first() if accelerator is not None else nullcontext(): - task_names_list, few_shots_dict = taskinfo_selector(args.tasks) - task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict( - task_names_list, custom_tasks=args.custom_tasks - ) - LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes) - - evaluation_tracker.task_config_logger.log(task_dict) - - hlog("Loading documents, and requests") - requests, docs = create_requests_from_tasks( - task_dict=task_dict, - fewshot_dict=few_shots_dict, - num_fewshot_seeds=args.num_fewshot_seeds, - lm=model, - max_samples=args.max_samples, - evaluation_tracker=evaluation_tracker, - use_chat_template=args.use_chat_template, - system_prompt=args.system_prompt, - ) - - with htrack_block("Setting seeds and waiting for all processes"): - hlog(f"setting seed to {1234} for random and numpy") - random.seed(1234) - np.random.seed(1234) - if accelerator is not None: - accelerator.wait_for_everyone() - - with htrack_block("Evaluation"): - hlog(f"Evaluate on {len(task_names_list)} tasks.") - evaluation_tracker = evaluate( - lm=model, - requests_dict=requests, - docs=docs, - task_dict=task_dict, - override_bs=args.override_batch_size, - evaluation_tracker=evaluation_tracker, - ) - - if accelerator.is_main_process if accelerator is not None else nullcontext(): - with htrack_block("Compiling and saving results"): - evaluation_tracker.general_config_logger.log_end_time() - evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000) - evaluation_tracker.details_logger.aggregate() - - if args.output_dir: - evaluation_tracker.save() - - final_dict = evaluation_tracker.generate_final_dict() - - with htrack_block("Cleaninp up"): - for weights in ["delta", "adapter"]: - try: - tmp_weights_dir = f"{evaluation_tracker.general_config_logger.model_name}-{weights}-applied" - hlog(f"Removing {tmp_weights_dir}") - shutil.rmtree(tmp_weights_dir) - except OSError: - pass - - print(make_results_table(final_dict)) - - model.cleanup() - return final_dict + + pipeline.evaluate() + + pipeline.show_results() + + results = pipeline.get_results() + + pipeline.save_and_push_results() + + return results From 5c9029c4ba7370bc68e803418286d2e5b9b19145 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 13:35:05 +0000 Subject: [PATCH 06/11] fix --- src/lighteval/models/model_loader.py | 10 +++++++++- src/lighteval/models/vllm_model.py | 5 +---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 417482a48..f8b968a1e 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -42,7 +42,13 @@ ) from lighteval.models.tgi_model import 
ModelClient from lighteval.models.vllm_model import VLLMModel -from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available +from lighteval.utils import ( + NO_TGI_ERROR_MSG, + NO_VLLM_ERROR_MSG, + is_accelerate_available, + is_tgi_available, + is_vllm_available, +) if is_accelerate_available(): @@ -142,6 +148,8 @@ def load_model_with_accelerate_or_default( elif isinstance(config, DeltaModelConfig): model = DeltaModel(config=config, env_config=env_config) elif isinstance(config, VLLMModelConfig): + if not is_vllm_available(): + raise ImportError(NO_VLLM_ERROR_MSG) model = VLLMModel(config=config, env_config=env_config) return model, ModelInfo(model_name="vllm", model_sha=str(config.seed)) else: diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py index 9d0d1ab90..dc8c712a4 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm_model.py @@ -43,12 +43,9 @@ GreedyUntilRequest, LoglikelihoodRequest, ) -from lighteval.utils import as_list, is_vllm_available +from lighteval.utils import as_list -if is_vllm_available(): - pass - os.environ["TOKENIZERS_PARALLELISM"] = "false" STARTING_BATCH_SIZE = 512 From 97632d9a77a0b0efb5fad3f7e79a859452ff5525 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 14:36:36 +0000 Subject: [PATCH 07/11] fix --- src/lighteval/models/model_config.py | 3 +++ src/lighteval/models/vllm_model.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index cddb95d3b..51185912d 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -333,6 +333,9 @@ def create_model_config( # noqa: C901 if model_args.pop("dummy", False): return DummyModelConfig(**model_args) + if model_args.pop("vllm", False): + return VLLMModelConfig(**model_args) + model_args["accelerator"] = accelerator model_args["use_chat_template"] = use_chat_template model_args["compile"] = bool(model_args["compile"]) if "compile" in model_args else False diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py index 56745b2e0..93fde27ae 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm_model.py @@ -38,7 +38,7 @@ GenerativeResponse, LoglikelihoodResponse, ) -from lighteval.models.utils import _simplify_name +from lighteval.models.utils import _get_dtype, _simplify_name from lighteval.tasks.requests import ( GreedyUntilRequest, LoglikelihoodRequest, @@ -76,9 +76,9 @@ def __init__( self.model_name = _simplify_name(config.pretrained) self.model_sha = "" # config.get_model_sha() - self.precision = "float16" # _get_dtype(config.dtype, config=self._config) + self.precision = _get_dtype(config.dtype, config=self._config) - self.model_info = ModelInfo() + self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha) @property def tokenizer(self): From fe8da513987362d681c6b5113fb6e2f5342c7d13 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 15:59:38 +0000 Subject: [PATCH 08/11] fix --- src/lighteval/models/vllm_model.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py index 93fde27ae..5f51a10d6 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm_model.py @@ -24,11 +24,8 @@ import os from typing import Optional -import ray from more_itertools import distribute from 
tqdm import tqdm -from vllm import LLM, SamplingParams -from vllm.transformers_utils.tokenizer import get_tokenizer from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset from lighteval.logging.hierarchical_logger import hlog_warn @@ -43,9 +40,20 @@ GreedyUntilRequest, LoglikelihoodRequest, ) +from lighteval.utils.imports import is_vllm_available from lighteval.utils.utils import EnvConfig, as_list +if is_vllm_available(): + import ray + from vllm import LLM, SamplingParams + from vllm.transformers_utils.tokenizer import get_tokenizer +else: + LLM = None + SamplingParams = None + get_tokenizer = None + ray = None + os.environ["TOKENIZERS_PARALLELISM"] = "false" STARTING_BATCH_SIZE = 512 @@ -242,11 +250,12 @@ def greedy_until( num_samples=num_samples, ) + print(f"{len(vllm_outputs)} vllm_outputs") for vllm_output in vllm_outputs: - output_token_ids = vllm_output.outputs[0].token_ids - logprobs = vllm_output.outputs[0].logprobs or [] - logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids, logprobs)] - result = vllm_output.outputs[0].text + output_token_ids = [outputs.token_ids for outputs in vllm_output.outputs] + logprobs = [output.logprobs for output in vllm_output.outputs] or [] + logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids[0], logprobs[0])] + result = [output.text for output in vllm_output.outputs] input_token_ids = vllm_output.prompt_token_ids cur_response = GenerativeResponse( From b52239db874b327ac03ad1fe658d6af58e10cab9 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 16:24:03 +0000 Subject: [PATCH 09/11] fix --- src/lighteval/models/base_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index a0750c786..a975d7cf1 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -574,7 +574,10 @@ def greedy_until( if max_new_tokens is None: # If generation size is not set, we go all the way max_new_tokens = self.max_length - context_size else: + print(self.max_length, context_size, max_new_tokens) max_new_tokens = min(self.max_length - context_size, max_new_tokens) + if max_new_tokens < 1: + max_new_tokens = 1 prepared_batch = Batch( input_ids=tokenized["input_ids"], From ed7b35d9b450a25c047971063e95646b0401f9be Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Tue, 3 Sep 2024 11:36:15 +0000 Subject: [PATCH 10/11] fix --- pyproject.toml | 2 +- src/lighteval/models/vllm_model.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7568b2c4..d9a9ed116 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ nanotron = [ "tensorboardX" ] tensorboardX = ["tensorboardX"] -vllm = ["vllm", "ray"] +vllm = ["vllm", "ray", "more_itertools"] quality = ["ruff==v0.2.2","pre-commit"] tests = ["pytest==7.4.0"] dev = ["lighteval[accelerate,quality,tests]"] diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py index 5f51a10d6..2a298bfa8 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm_model.py @@ -24,7 +24,6 @@ import os from typing import Optional -from more_itertools import distribute from tqdm import tqdm from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset @@ -46,6 +45,7 @@ if is_vllm_available(): import ray + from more_itertools import distribute from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer 
 else:
@@ -53,6 +53,7 @@
     SamplingParams = None
     get_tokenizer = None
     ray = None
+    distribute = None

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

From b756f4c078c21fade72234c330f406f36b7069c2 Mon Sep 17 00:00:00 2001
From: Nathan Habib
Date: Tue, 3 Sep 2024 13:32:38 +0000
Subject: [PATCH 11/11] remove unnecessary print

---
 src/lighteval/models/vllm_model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py
index 2a298bfa8..aec4d441b 100644
--- a/src/lighteval/models/vllm_model.py
+++ b/src/lighteval/models/vllm_model.py
@@ -78,7 +78,6 @@ def __init__(
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = self._create_auto_model(config, env_config)
-        print(f"model: {self.model}")

         # self._device = config.accelerator.device if config.accelerator is not None else "cpu"
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
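After PATCH 08/11 and 10/11 the module-level vLLM imports are fully guarded; the resulting pattern, sketched in isolation, keeps vllm_model.py importable when the optional dependencies are absent:

    from lighteval.utils.imports import is_vllm_available

    if is_vllm_available():
        import ray
        from more_itertools import distribute
        from vllm import LLM, SamplingParams
        from vllm.transformers_utils.tokenizer import get_tokenizer
    else:
        # Placeholders keep the module importable without vllm/ray; load_model raises
        # ImportError(NO_VLLM_ERROR_MSG) before any of these names are actually used.
        LLM = SamplingParams = get_tokenizer = ray = distribute = None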