diff --git a/.gitignore b/.gitignore
index 7b47d71d4..77e347cf5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,21 +164,12 @@ tests/.data
 tests/data
 
 # outputs folder
-examples/*/outputs
-examples/*/NeMo_experiments
-examples/*/nemo_experiments
-examples/*/.hydra
-examples/*/wandb
-examples/*/data
 wandb
 dump.py
 
 docs/sources/source/test_build/
 
 # Checkpoints, config files and temporary files created in tutorials.
-examples/neural_graphs/*.chkpt
-examples/neural_graphs/*.yml
-
 .hydra/
 nemo_experiments/
@@ -186,7 +177,6 @@ nemo_experiments/
 tmp.py
 
-examples
 benchmark_output
 prod_env
 
diff --git a/README.md b/README.md
index 808b71505..2ed508293 100644
--- a/README.md
+++ b/README.md
@@ -10,9 +10,6 @@ We're releasing it with the community in the spirit of building in the open.
 Note that it is still very much early so don't expect 100% stability ^^'
 In case of problems or question, feel free to open an issue!
 
-## News
-- **Feb 08, 2024**: Release of `lighteval`
-
 ## Installation
 
 Clone the repo:
@@ -98,7 +95,7 @@ Here, `--tasks` refers to either a _comma-separated_ list of supported tasks fro
 suite|task|num_few_shot|{0 or 1 to automatically reduce `num_few_shot` if prompt is too long}
 ```
 
-or a file path like [`tasks_examples/recommended_set.txt`](./tasks_examples/recommended_set.txt) which specifies multiple task configurations. For example, to evaluate GPT-2 on the Truthful QA benchmark run:
+or a file path like [`examples/tasks/recommended_set.txt`](./examples/tasks/recommended_set.txt) which specifies multiple task configurations. For example, to evaluate GPT-2 on the Truthful QA benchmark run:
 
 ```shell
 accelerate launch --multi_gpu --num_processes=8 run_evals_accelerate.py \
@@ -118,7 +115,20 @@ accelerate launch --multi_gpu --num_processes=8 run_evals_accelerate.py \
     --output_dir="./evals/"
 ```
 
-See the [`tasks_examples/recommended_set.txt`](./tasks_examples/recommended_set.txt) file for a list of recommended task configurations.
+See the [`examples/tasks/recommended_set.txt`](./examples/tasks/recommended_set.txt) file for a list of recommended task configurations.
+
+### Evaluating a model with a complex configuration
+
+If you want to evaluate a model by spinning up inference endpoints, use adapter/delta weights, or rely on more complex configuration options, you can load models using a configuration file. This is done as follows:
+
+```shell
+accelerate launch --multi_gpu --num_processes=<num_gpus> run_evals_accelerate.py \
+    --model_config_path="<path to your model configuration>" \
+    --tasks <task parameters> \
+    --output_dir output_dir
+```
+
+Examples of possible configuration files are provided in `examples/model_configs`.
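+
+For reference, a minimal `base`-type configuration in the style of `examples/model_configs/base_model.yaml` might look roughly like this (the model name and dtype shown are illustrative):
+
+```yaml
+model:
+  type: "base"  # can be base, tgi, or endpoint
+  base_params:
+    model_args: "pretrained=HuggingFaceH4/zephyr-7b-beta,revision=main"
+    dtype: "bfloat16"
+  merged_weights:  # ignore this section if you are not using PEFT models
+    delta_weights: false
+    adapter_weights: false
+    base_model: null
+  generation:
+    multichoice_continuations_start_space: false
+    no_multichoice_continuations_start_space: false
+```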
 
 ### Evaluating a large model with pipeline parallelism
@@ -127,15 +137,13 @@ To evaluate models larger that ~40B parameters in 16-bit precision, you will nee
 ```shell
 # PP=2, DP=4 - good for models < 70B params
 accelerate launch --multi_gpu --num_processes=4 run_evals_accelerate.py \
-    --model_args="pretrained=<model name>" \
-    --model_parallel \
+    --model_args="pretrained=<model name>,model_parallel=True" \
     --tasks <task parameters> \
     --output_dir output_dir
 
 # PP=4, DP=2 - good for huge models >= 70B params
 accelerate launch --multi_gpu --num_processes=2 run_evals_accelerate.py \
-    --model_args="pretrained=<model name>" \
-    --model_parallel \
+    --model_args="pretrained=<model name>,model_parallel=True" \
     --tasks <task parameters> \
     --output_dir output_dir
 ```
@@ -147,7 +155,7 @@ To evaluate a model on all the benchmarks of the [Open LLM Leaderboard](https://
 ```shell
 accelerate launch --multi_gpu --num_processes=8 run_evals_accelerate.py \
     --model_args "pretrained=<model name>" \
-    --tasks tasks_examples/open_llm_leaderboard_tasks.txt \
+    --tasks examples/tasks/open_llm_leaderboard_tasks.txt \
     --override_batch_size 1 \
     --output_dir="./evals/"
 ```
@@ -220,7 +228,7 @@ However, we are very grateful to the Harness and HELM teams for their continued
 - [metrics](https://github.com/huggingface/lighteval/tree/main/src/lighteval/metrics): All the available metrics you can use. They are described in metrics, and divided between sample metrics (applied at the sample level, such as a prediction accuracy) and corpus metrics (applied over the whole corpus). You'll also find available normalisation functions.
 - [models](https://github.com/huggingface/lighteval/tree/main/src/lighteval/models): Possible models to use. We cover transformers (base_model), with adapter or delta weights, as well as TGI models locally deployed (it's likely the code here is out of date though), and brrr/nanotron models.
 - [tasks](https://github.com/huggingface/lighteval/tree/main/src/lighteval/tasks): Available tasks. The complete list is in `tasks_table.jsonl`, and you'll find all the prompts in `tasks_prompt_formatting.py`. Popular tasks requiring custom logic are exceptionally added in the [extended tasks](https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/extended).
-- [tasks_examples](https://github.com/huggingface/lighteval/tree/main/tasks_examples) contains a list of available tasks you can launch. We advise using tasks in the `recommended_set`, as it's possible that some of the other tasks need double checking.
+- [examples/tasks](https://github.com/huggingface/lighteval/tree/main/examples/tasks) contains a list of available tasks you can launch. We advise using tasks in the `recommended_set`, as it's possible that some of the other tasks need double checking.
 - [tests](https://github.com/huggingface/lighteval/tree/main/tests) contains our test suite, that we run at each PR to prevent regressions in metrics/prompts/tasks, for a subset of important tasks.
 
 ## Customisation
@@ -291,7 +299,7 @@ if __name__ == "__main__":
 
 You can then give your custom metric to lighteval by using `--custom-tasks path_to_your_file` when launching it.
 
-To see an example of a custom metric added along with a custom task, look at `tasks_examples/custom_tasks_with_custom_metrics/ifeval/ifeval.py`.
+To see an example of a custom metric added along with a custom task, look at `examples/tasks/custom_tasks_with_custom_metrics/ifeval/ifeval.py`.
 
 ## Available metrics
 ### Metrics for multiple choice tasks
@@ -414,7 +422,7 @@ source /activate #or conda activate yourenv
 cd /lighteval
 export CUDA_LAUNCH_BLOCKING=1
-srun accelerate launch --multi_gpu --num_processes=8 run_evals_accelerate.py --model_args "pretrained=your model name" --tasks tasks_examples/open_llm_leaderboard_tasks.txt --override_batch_size 1 --save_details --output_dir=your output dir
+srun accelerate launch --multi_gpu --num_processes=8 run_evals_accelerate.py --model_args "pretrained=your model name" --tasks examples/tasks/open_llm_leaderboard_tasks.txt --override_batch_size 1 --save_details --output_dir=your output dir
 ```
 
 ## Releases
diff --git a/examples/model_configs/base_model.yaml b/examples/model_configs/base_model.yaml
new file mode 100644
index 000000000..43d868876
--- /dev/null
+++ b/examples/model_configs/base_model.yaml
@@ -0,0 +1,12 @@
+model:
+  type: "base" # can be base, tgi, or endpoint
+  base_params:
+    model_args: "pretrained=HuggingFaceH4/zephyr-7b-beta,revision=main" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ...
+    dtype: "bfloat16"
+  merged_weights: # Ignore this section if you are not using PEFT models
+    delta_weights: false # set to True if your model should be merged with a base model, also need to provide the base model name
+    adapter_weights: false # set to True if your model has been trained with peft, also need to provide the base model name
+    base_model: null # path to the base_model
+  generation:
+    multichoice_continuations_start_space: false # Whether to force multiple choice continuations to start with a space
+    no_multichoice_continuations_start_space: false # Whether to force multiple choice continuations to not start with a space
diff --git a/examples/model_configs/endpoint_model.yaml b/examples/model_configs/endpoint_model.yaml
new file mode 100644
index 000000000..2f050bc65
--- /dev/null
+++ b/examples/model_configs/endpoint_model.yaml
@@ -0,0 +1,18 @@
+model:
+  type: "endpoint" # can be base, tgi, or endpoint
+  base_params:
+    endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
+    model: "meta-llama/Llama-2-7b-hf"
+    revision: "main"
+    dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit" or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
+    reuse_existing: false # if true, ignore all params in instance
+  instance:
+    accelerator: "gpu"
+    region: "eu-west-1"
+    vendor: "aws"
+    instance_size: "medium"
+    instance_type: "g5.2xlarge"
+    framework: "pytorch"
+    endpoint_type: "protected"
+  generation:
+    add_special_tokens: true
diff --git a/examples/model_configs/tgi_model.yaml b/examples/model_configs/tgi_model.yaml
new file mode 100644
index 000000000..4cfb80860
--- /dev/null
+++ b/examples/model_configs/tgi_model.yaml
@@ -0,0 +1,5 @@
+model:
+  type: "tgi" # can be base, tgi, or endpoint
+  instance:
+    inference_server_address: ""
+    inference_server_auth: null
diff --git a/tasks_examples/custom_tasks/custom_evaluation_tasks.py b/examples/nanotron/custom_evaluation_tasks.py
similarity index 100%
rename from tasks_examples/custom_tasks/custom_evaluation_tasks.py
rename to examples/nanotron/custom_evaluation_tasks.py
diff --git a/tasks_examples/custom_tasks/custom_task.py b/examples/nanotron/custom_task.py
similarity index 100%
rename from tasks_examples/custom_tasks/custom_task.py
rename to examples/nanotron/custom_task.py
diff --git a/tasks_examples/custom_tasks/lighteval_config_override_template.yaml b/examples/nanotron/lighteval_config_override_template.yaml
similarity index 100%
rename from tasks_examples/custom_tasks/lighteval_config_override_template.yaml
rename to examples/nanotron/lighteval_config_override_template.yaml
diff --git a/tasks_examples/OALL_tasks.txt b/examples/tasks/OALL_tasks.txt
similarity index 100%
rename from tasks_examples/OALL_tasks.txt
rename to examples/tasks/OALL_tasks.txt
diff --git a/tasks_examples/all_arabic_tasks.txt b/examples/tasks/all_arabic_tasks.txt
similarity index 100%
rename from tasks_examples/all_arabic_tasks.txt
rename to examples/tasks/all_arabic_tasks.txt
diff --git a/tasks_examples/all_tasks.txt b/examples/tasks/all_tasks.txt
similarity index 100%
rename from tasks_examples/all_tasks.txt
rename to examples/tasks/all_tasks.txt
diff --git a/tasks_examples/bbh.txt b/examples/tasks/bbh.txt
similarity index 100%
rename from tasks_examples/bbh.txt
rename to examples/tasks/bbh.txt
diff --git a/tasks_examples/open_llm_leaderboard_tasks.txt b/examples/tasks/open_llm_leaderboard_tasks.txt
similarity index 100%
rename from tasks_examples/open_llm_leaderboard_tasks.txt
rename to examples/tasks/open_llm_leaderboard_tasks.txt
diff --git a/tasks_examples/recommended_set.txt b/examples/tasks/recommended_set.txt
similarity index 100%
rename from tasks_examples/recommended_set.txt
rename to examples/tasks/recommended_set.txt
diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py
index 76e4bb653..a743cb496 100644
--- a/run_evals_accelerate.py
+++ b/run_evals_accelerate.py
@@ -34,54 +34,16 @@ def get_parser():
     group = parser.add_mutually_exclusive_group(required=True)
     task_type_group = parser.add_mutually_exclusive_group(required=True)
 
-    # Model type 1) Base model
-    weight_type_group = parser.add_mutually_exclusive_group()
-    weight_type_group.add_argument(
-        "--delta_weights",
-        action="store_true",
-        default=False,
-        help="set to True of your model should be merged with a base model, also need to provide the base model name",
-    )
-    weight_type_group.add_argument(
-        "--adapter_weights",
-        action="store_true",
-        default=False,
-        help="set to True of your model has been trained with peft, also need to provide the base model name",
-    )
-    parser.add_argument(
-        "--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights"
-    )
-
+    # Model type: either use a config file or simply the model name
+    task_type_group.add_argument("--model_config_path")
     task_type_group.add_argument("--model_args")
-    parser.add_argument("--model_dtype", type=str, default=None)
-    parser.add_argument(
-        "--multichoice_continuations_start_space",
-        action="store_true",
-        help="Whether to force multiple choice continuations to start with a space",
-    )
-    parser.add_argument(
-        "--no_multichoice_continuations_start_space",
-        action="store_true",
-        help="Whether to force multiple choice continuations to not start with a space",
-    )
-    parser.add_argument("--use_chat_template", default=False, action="store_true")
-    parser.add_argument("--system_prompt", type=str, default=None)
-    # Model type 2) TGI
-    task_type_group.add_argument("--inference_server_address", type=str)
-    parser.add_argument("--inference_server_auth", type=str, default=None)
-    # Model type 3) Inference endpoints
-    task_type_group.add_argument("--endpoint_model_name", type=str)
-    parser.add_argument("--revision", type=str)
-    parser.add_argument("--accelerator", type=str, default=None)
-    parser.add_argument("--vendor", type=str, default=None)
-    parser.add_argument("--region", type=str, default=None)
-    parser.add_argument("--instance_size", type=str, default=None)
-    parser.add_argument("--instance_type", type=str, default=None)
-    parser.add_argument("--reuse_existing", default=False, action="store_true")
+    # Debug
     parser.add_argument("--max_samples", type=int, default=None)
+    parser.add_argument("--override_batch_size", type=int, default=-1)
     parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="")
     # Saving
+    parser.add_argument("--output_dir", required=True)
     parser.add_argument("--push_results_to_hub", default=False, action="store_true")
     parser.add_argument("--save_details", action="store_true")
     parser.add_argument("--push_details_to_hub", default=False, action="store_true")
@@ -95,8 +57,8 @@ def get_parser():
         help="Hub organisation where you want to store the results. Your current token must have write access to it",
     )
     # Common parameters
-    parser.add_argument("--output_dir", required=True)
-    parser.add_argument("--override_batch_size", type=int, default=-1)
+    parser.add_argument("--use_chat_template", default=False, action="store_true")
+    parser.add_argument("--system_prompt", type=str, default=None)
     parser.add_argument("--dataset_loading_processes", type=int, default=1)
     parser.add_argument(
         "--custom_tasks",
diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py
index 06aaa1947..d2ffbbe3b 100644
--- a/src/lighteval/main_accelerate.py
+++ b/src/lighteval/main_accelerate.py
@@ -131,18 +131,15 @@ def main(args):
         final_dict = evaluation_tracker.generate_final_dict()
 
     with htrack_block("Cleaninp up"):
-        if args.delta_weights:
-            tmp_weights_dir = f"{evaluation_tracker.general_config_logger.model_name}-delta-applied"
-            hlog(f"Removing {tmp_weights_dir}")
-            shutil.rmtree(tmp_weights_dir)
-        if args.adapter_weights:
-            tmp_weights_dir = f"{evaluation_tracker.general_config_logger.model_name}-adapter-applied"
-            hlog(f"Removing {tmp_weights_dir}")
-            shutil.rmtree(tmp_weights_dir)
+        for weights in ["delta", "adapter"]:
+            try:
+                tmp_weights_dir = f"{evaluation_tracker.general_config_logger.model_name}-{weights}-applied"
+                hlog(f"Removing {tmp_weights_dir}")
+                shutil.rmtree(tmp_weights_dir)
+            except OSError:
+                pass
 
     print(make_results_table(final_dict))
 
-    if not args.reuse_existing:
-        model.cleanup()
-
+    model.cleanup()
     return final_dict
diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index 1c2d0ce00..5dbaa750a 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -42,7 +42,7 @@
     LoglikelihoodReturn,
     LoglikelihoodSingleTokenReturn,
 )
-from lighteval.models.utils import _get_dtype, _get_precision, _simplify_name, batched
+from lighteval.models.utils import _get_dtype, _simplify_name, batched
 from lighteval.tasks.requests import (
     GreedyUntilMultiTurnRequest,
     GreedyUntilRequest,
@@ -88,7 +88,7 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
 
         # We are in DP (and launch the script with `accelerate launch`)
-        if not config.model_parallel and not config.load_in_4bit and not config.load_in_8bit:
+        if not config.model_parallel and config.quantization_config is None:
             # might need to use accelerate instead
             # self.model = config.accelerator.prepare(self.model)
             hlog(f"Using Data Parallelism, putting model on device {self._device}")
@@ -97,7 +97,7 @@ def __init__(
         self.model_name = _simplify_name(config.pretrained)
         self.model_sha = config.get_model_sha()
-        self.precision = _get_precision(config, model_auto_config=self._config)
+        self.precision = _get_dtype(config.dtype, config=self._config)
 
     @property
     def tokenizer(self):
diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index 38a09502d..6bd5cea79 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -63,6 +63,7 @@ class InferenceEndpointModel(LightevalModel):
     def __init__(
         self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
     ) -> None:
+        self.reuse_existing = getattr(config, "should_reuse_existing", True)
         if isinstance(config, InferenceEndpointModelConfig):
             if config.should_reuse_existing:
                 self.endpoint = get_inference_endpoint(name=config.name, token=env_config.token)
@@ -130,7 +131,7 @@ def disable_tqdm(self) -> bool:
         False  # no accelerator = this is the main process
 
     def cleanup(self):
-        if self.endpoint is not None:
+        if self.endpoint is not None and not self.reuse_existing:
             self.endpoint.delete()
             hlog_warn(
                 "You deleted your endpoint after using it. You'll need to create it again if you need to reuse it."
diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index 19c8bdaf8..03755b1d5 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -25,6 +25,7 @@
 from typing import Dict, Optional, Union
 
 import torch
+import yaml
 from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig, PretrainedConfig
 
 from lighteval.logging.hierarchical_logger import hlog
@@ -57,27 +58,6 @@ class EnvConfig:
     cache_dir: str = None
     token: str = None
 
-    """Args:
-        pretrained (str):
-            HuggingFace Hub model ID name or the path to a pre-trained
-            model to load. This is effectively the `pretrained_model_name_or_path`
-            argument of `from_pretrained` in the HuggingFace `transformers` API.
-        add_special_tokens (bool, optional, defaults to True):
-            Whether to add special tokens to the input sequences. If `None`, the
-            default value will be set to `True` for seq2seq models (e.g. T5) and
-            `False` for causal models.
-        > Large model loading `accelerate` arguments
-        model_parallel (bool, optional, defaults to False):
-            True/False: force to uses or not the `accelerate` library to load a large
-            model across multiple devices.
-            Default: None which correspond to comparing the number of process with
-            the number of GPUs. If it's smaller => model-parallelism, else not.
-        dtype (Union[str, torch.dtype], optional, defaults to None):):
-            Converts the model weights to `dtype`, if specified. Strings get
-            converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
-            Use `dtype="auto"` to derive the type from the model's weights.
-    """
-
 
 @dataclass
 class BaseModelConfig:
@@ -85,10 +65,10 @@ class BaseModelConfig:
     Base configuration class for models.
 
     Attributes:
-        pretrained (str): HuggingFace Hub model ID name or the path to a
-            pre-trained model to load. This is effectively the
-            `pretrained_model_name_or_path` argument of `from_pretrained` in the
-            HuggingFace `transformers` API.
+        pretrained (str):
+            HuggingFace Hub model ID name or the path to a pre-trained
+            model to load. This is effectively the `pretrained_model_name_or_path`
+            argument of `from_pretrained` in the HuggingFace `transformers` API.
         accelerator (Accelerator): accelerator to use for model training.
         tokenizer (Optional[str]): HuggingFace Hub tokenizer ID that will be used for tokenization.
@@ -104,13 +84,18 @@ class BaseModelConfig:
         add_special_tokens (bool, optional, defaults to True):
             Whether to add special tokens to the input sequences. If `None`, the
             default value will be set to `True` for seq2seq models (e.g. T5) and
             `False` for causal models.
-        model_parallel (Optional[bool]): Whether to use model parallelism.
-        dtype (Optional[Union[str, torch.dtype]]): data type of the model.
+        model_parallel (bool, optional, defaults to None):
+            Whether to force use of the `accelerate` library to load a large
+            model across multiple devices.
+            Defaults to None, which compares the number of processes with
+            the number of GPUs: if it is smaller, model parallelism is used; otherwise it is not.
+        dtype (Union[str, torch.dtype], optional, defaults to None):
+            Converts the model weights to `dtype`, if specified. Strings get
+            converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
+            Use `dtype="auto"` to derive the type from the model's weights.
         device (Union[int, str]): device to use for model training.
         quantization_config (Optional[BitsAndBytesConfig]): quantization
             configuration for the model. Needed for 4-bit and 8-bit precision.
-        load_in_8bit (bool): Whether to load the model in 8-bit precision.
-        load_in_4bit (bool): Whether to load the model in 4-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
             loading.
@@ -136,8 +121,6 @@ class BaseModelConfig:
     dtype: Optional[Union[str, torch.dtype]] = None
     device: Union[int, str] = "cuda"
     quantization_config: Optional[BitsAndBytesConfig] = None
-    load_in_8bit: bool = None
-    load_in_4bit: bool = None
     trust_remote_code: bool = False
 
     def __post_init__(self):
@@ -270,74 +253,82 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
         ValueError: If a base model is not specified when using delta weights or adapter weights.
         ValueError: If a base model is specified when not using delta weights or adapter weights.
     """
-    if args.inference_server_address is not None and args.model_args is not None:
-        raise ValueError("You cannot both use an inference server and load a model from its checkpoint.")
-    if args.inference_server_address is not None and args.endpoint_model_name is not None:
-        raise ValueError("You cannot both use a local inference server and load a model from an inference endpoint.")
-    if args.endpoint_model_name is not None and args.model_args is not None:
-        raise ValueError("You cannot both load a model from its checkpoint and from an inference endpoint.")
-
-    # TGI
-    if args.inference_server_address is not None:
-        return TGIModelConfig(
-            inference_server_address=args.inference_server_address, inference_server_auth=args.inference_server_auth
-        )
+    if args.model_args:
+        args_dict = {k.split("=")[0]: k.split("=")[1] for k in args.model_args.split(",")}
+        args_dict["accelerator"] = accelerator
 
-    # Endpoint
-    if args.endpoint_model_name:
-        if args.reuse_existing or args.vendor is not None:
-            model = args.endpoint_model_name.split("/")[1].replace(".", "-").lower()
-            return InferenceEndpointModelConfig(
-                name=f"{model}-lighteval",
-                repository=args.endpoint_model_name,
-                accelerator=args.accelerator,
-                region=args.region,
-                vendor=args.vendor,
-                instance_size=args.instance_size,
-                instance_type=args.instance_type,
-                should_reuse_existing=args.reuse_existing,
-                model_dtype=args.model_dtype,
-                revision=args.revision or "main",
+        return BaseModelConfig(**args_dict)
+
+    with open(args.model_config_path, "r") as f:
+        config = yaml.safe_load(f)["model"]
+
+    if config["type"] == "tgi":
+        return TGIModelConfig(
+            inference_server_address=config["instance"]["inference_server_address"],
+            inference_server_auth=config["instance"]["inference_server_auth"],
         )
-        return InferenceModelConfig(model=args.endpoint_model_name)
-
-    # Base
-    multichoice_continuations_start_space = args.multichoice_continuations_start_space
-    if not multichoice_continuations_start_space and not args.no_multichoice_continuations_start_space:
-        multichoice_continuations_start_space = None
-    if args.multichoice_continuations_start_space and args.no_multichoice_continuations_start_space:
-        raise ValueError(
-            "You cannot force both the multichoice continuations to start with a space and not to start with a space"
-        )
-    if "load_in_4bit=True" in args.model_args:
-        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
-    elif "load_in_8bit=True" in args.model_args:
-        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-    else:
-        quantization_config = None
-
-    # We extract the model args
-    args_dict = {k.split("=")[0]: k.split("=")[1] for k in args.model_args.split(",")}
-    # We store the relevant other args
-    args_dict["base_model"] = args.base_model
-    args_dict["batch_size"] = args.override_batch_size
-    args_dict["accelerator"] = accelerator
-    args_dict["quantization_config"] = quantization_config
-    args_dict["dtype"] = args.model_dtype
-    args_dict["multichoice_continuations_start_space"] = multichoice_continuations_start_space
-
-    # Keeping only non null params
-    args_dict = {k: v for k, v in args_dict.items() if v is not None}
-
-    if args.delta_weights:
-        if args.base_model is None:
-            raise ValueError("You need to specify a base model when using delta weights")
-        return DeltaModelConfig(**args_dict)
-    if args.adapter_weights:
-        if args.base_model is None:
-            raise ValueError("You need to specify a base model when using adapter weights")
-        return AdapterModelConfig(**args_dict)
-    if args.base_model is not None:
-        raise ValueError("You can't specifify a base model if you are not using delta/adapter weights")
-    return BaseModelConfig(**args_dict)
+    if config["type"] == "endpoint":
+        reuse_existing_endpoint = config["base_params"]["reuse_existing"]
+        complete_config_endpoint = all(val not in [None, ""] for val in config["instance"].values())
+        if reuse_existing_endpoint or complete_config_endpoint:
+            return InferenceEndpointModelConfig(
+                name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
+                repository=config["base_params"]["model"],
+                model_dtype=config["base_params"]["dtype"],
+                revision=config["base_params"]["revision"] or "main",
+                should_reuse_existing=reuse_existing_endpoint,
+                accelerator=config["instance"]["accelerator"],
+                region=config["instance"]["region"],
+                vendor=config["instance"]["vendor"],
+                instance_size=config["instance"]["instance_size"],
+                instance_type=config["instance"]["instance_type"],
+            )
+        return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+
+    if config["type"] == "base":
+        # Tests on the multichoice space parameters
+        multichoice_continuations_start_space = config["generation"]["multichoice_continuations_start_space"]
+        no_multichoice_continuations_start_space = config["generation"]["no_multichoice_continuations_start_space"]
+        if not multichoice_continuations_start_space and not no_multichoice_continuations_start_space:
+            multichoice_continuations_start_space = None
+        if multichoice_continuations_start_space and no_multichoice_continuations_start_space:
+            raise ValueError(
+                "You cannot force both the multichoice continuations to start with a space and not to start with a space"
+            )
+
+        # Creating optional quantization configuration
+        if config["base_params"]["dtype"] == "4bit":
+            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+        elif config["base_params"]["dtype"] == "8bit":
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            quantization_config = None
+
+        # We extract the model args
+        args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}
+
+        # We store the relevant other args
+        args_dict["base_model"] = config["merged_weights"]["base_model"]
+        args_dict["dtype"] = config["base_params"]["dtype"]
+        args_dict["accelerator"] = accelerator
+        args_dict["quantization_config"] = quantization_config
+        args_dict["batch_size"] = args.override_batch_size
+        args_dict["multichoice_continuations_start_space"] = multichoice_continuations_start_space
+
+        # Keeping only non null params
+        args_dict = {k: v for k, v in args_dict.items() if v is not None}
+
+        if config["merged_weights"]["delta_weights"]:
+            if config["merged_weights"]["base_model"] is None:
+                raise ValueError("You need to specify a base model when using delta weights")
+            return DeltaModelConfig(**args_dict)
+        if config["merged_weights"]["adapter_weights"]:
+            if config["merged_weights"]["base_model"] is None:
+                raise ValueError("You need to specify a base model when using adapter weights")
+            return AdapterModelConfig(**args_dict)
+        if config["merged_weights"]["base_model"] not in ["", None]:
+            raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
+        return BaseModelConfig(**args_dict)
+
+    raise ValueError(f"Unknown model type in your model config file: {config['type']}")
diff --git a/src/lighteval/models/utils.py b/src/lighteval/models/utils.py
index ba9681515..ec4a22758 100644
--- a/src/lighteval/models/utils.py
+++ b/src/lighteval/models/utils.py
@@ -22,17 +22,13 @@
 import os
 from itertools import islice
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union
 
 import torch
 from huggingface_hub import HfApi
 from transformers import AutoConfig
 
-if TYPE_CHECKING:
-    from lighteval.models.model_config import BaseModelConfig
-
-
 def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = None) -> torch.dtype:
     """
     Get the torch dtype based on the input arguments.
@@ -45,13 +41,12 @@ def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = No
         torch.dtype: The torch dtype based on the input arguments.
     """
-    if dtype is None and config is not None:
-        # For quantized models
+    if config is not None:  # For quantized models
         if hasattr(config, "quantization_config"):
             _torch_dtype = None  # must be inferred
         else:
             _torch_dtype = config.torch_dtype
-    elif isinstance(dtype, str) and dtype != "auto":
+    elif isinstance(dtype, str) and dtype not in ["auto", "4bit", "8bit"]:
         # Convert `str` args torch dtype: `float16` -> `torch.float16`
         _torch_dtype = getattr(torch, dtype)
     else:
@@ -87,26 +82,6 @@ def _simplify_name(name_or_path: str) -> str:
     return name_or_path
 
 
-def _get_precision(config: "BaseModelConfig", model_auto_config: AutoConfig):
-    """
-    Returns the precision of the model.
-
-    Args:
-        dtype (Union[str, torch.dtype]): The data type of the model.
-        load_in_8bit (bool): Whether to load the weights in 8-bit precision.
-        load_in_4bit (bool): Whether to load the weights in 4-bit precision.
-
-    Returns:
-        str: The selected precision for loading the model weights.
-    """
-    if config.load_in_8bit:
-        return "8bit"
-    elif config.load_in_4bit:
-        return "4bit"
-    else:
-        return _get_dtype(config.dtype, model_auto_config)
-
-
 def _get_model_sha(repo_id: str, revision: str):
     api = HfApi()
     try: