diff --git a/examples/model_configs/transformers_model.yaml b/examples/model_configs/transformers_model.yaml
index e55e79c4f..7b7230174 100644
--- a/examples/model_configs/transformers_model.yaml
+++ b/examples/model_configs/transformers_model.yaml
@@ -5,7 +5,13 @@ model_parameters:
   compile: false
   model_parallel: false
   batch_size: 1
-  multichoice_continuations_start_space: null # If true/false, will force multiple choice continuations to start/not start with a space. If none, will do nothing
+  continuous_batching: false
+  model_loading_kwargs:
+    attn_implementation: "eager"
+    #tp_plan: "auto"
   generation_parameters:
+    #num_blocks: 4096
+    #block_size: 64
+    #max_new_tokens: 256
     temperature: 0.0
     top_p: 0.9
diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py
index d4e3d2bd2..2d8a53fcb 100644
--- a/src/lighteval/models/model_input.py
+++ b/src/lighteval/models/model_input.py
@@ -25,6 +25,9 @@ class GenerationParameters(BaseModel, extra="forbid"):
+    num_blocks: NonNegativeInt | None = None  # transformers
+    block_size: NonNegativeInt | None = None  # transformers
+
     early_stopping: bool | None = None  # transformers
     repetition_penalty: NonNegativeFloat | None = None  # vllm, transformers, tgi, sglang
     frequency_penalty: NonNegativeFloat | None = None  # vllm, tgi, sglang
@@ -186,6 +189,8 @@ def to_transformers_dict(self) -> dict:
             "repetition_penalty": self.repetition_penalty,
             "length_penalty": self.length_penalty,
             "output_scores": True,
+            "num_blocks": self.num_blocks,
+            "block_size": self.block_size,
             "return_dict_in_generate": True,
         }
         return {k: v for k, v in args.items() if v is not None}
diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py
index 9650f7956..fef8b0b5b 100644
--- a/src/lighteval/models/transformers/transformers_model.py
+++ b/src/lighteval/models/transformers/transformers_model.py
@@ -23,7 +23,7 @@
 import logging
 import os
 from datetime import timedelta
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -41,6 +41,7 @@
     BitsAndBytesConfig,
     PretrainedConfig,
 )
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerateOutput
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
@@ -108,6 +109,8 @@ class TransformersModelConfig(ModelConfig):
             True forces adding space, False removes leading space if present.
         pairwise_tokenization (bool):
             Whether to tokenize context and continuation separately or together. Defaults to False.
+        continuous_batching (bool):
+            Whether to use continuous batching for generation. Defaults to False.
 
     Example:
     ```python
@@ -143,6 +146,7 @@ class TransformersModelConfig(ModelConfig):
     compile: bool = False
     multichoice_continuations_start_space: bool | None = None
     pairwise_tokenization: bool = False
+    continuous_batching: bool = False
 
     def model_post_init(self, __context):
         if self.multichoice_continuations_start_space is True:
@@ -185,7 +189,9 @@ def __init__(
         self._add_special_tokens = config.add_special_tokens or False
         self.pairwise_tokenization = config.pairwise_tokenization
         self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
         self.transformers_config = config.get_transformers_config()
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
 
         self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
@@ -206,8 +212,6 @@ def __init__(
 
         self.model_name = _simplify_name(config.model_name)
 
-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
-
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
             model_size = convert_bytes(model_size)
@@ -252,14 +256,15 @@ def from_model(
         # Instanciate the object without using __init__
         self = cls.__new__(cls)
 
-        self.config = config
         self.transformers_config = model.config
-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
+        # Read generation parameters from self.config so they are set even when no config was passed
+        self.generation_config_dict = self.config.generation_parameters.to_transformers_dict()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = config.batch_size
+        self.batch_size = getattr(config, "batch_size", None)
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = config.get_model_sha()
+        self.model_sha = self.config.get_model_sha()
 
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -398,6 +403,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         # model.to(self.device)
         model.eval()
         torch.set_grad_enabled(False)
+        if self.continuous_batching:
+            generation_config = GenerationConfig(
+                **self.generation_config_dict,
+            )
+            model.generation_config = generation_config
 
         if self.config.compile:
             try:
@@ -500,7 +510,110 @@ def forward_batch(batch_size):
         logger.info(f"Determined largest batch size: {batch_size}")
         return batch_size
 
-    def greedy_until(
+    def _continuous_greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            docs (list[Doc]): list of documents containing the context and ending
+                conditions for generation.
+
+        Returns:
+            list[ModelResponse]: list of generated responses.
+        """
+        dataset = GenerativeTaskDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
+        results = []
+
+        for split in tqdm(
+            dataset.splits_iterator(),
+            total=dataset.num_dataset_splits,
+            desc="Splits",
+            position=0,
+            disable=False,  # self.disable_tqdm,
+        ):
+            # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
+            if self.use_chat_template:
+                stop_tokens = []
+            else:
+                # NOTE: we are assuming all items in a batch behave similarly (same
+                # stop_tokens and max_tokens generated) which is not necessarily
+                # the case! Because of that we only use batch size of 1
+                stop_tokens = split[0].stop_sequence
+
+            max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
+            returns_logits = split[0].use_logits
+            num_samples = split[0].num_samples
+            contexts = [self.prompt_manager.prepare_prompt(doc) for doc in split]
+            tokenized = self.tokenizer(contexts, add_special_tokens=self.add_special_tokens)
+
+            # The main question for this step is the following:
+            # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+            # of losing some meaning, or have some generations that are exceedingly short?
+            # The choice we go for here is to avoid truncating the prompt if we can, since it
+            # should have been managed by the prompt creator/few shot manager if requested by the user.
+            inputs = tokenized["input_ids"]
+            context_size = len(inputs[0])
+
+            # left truncate the inputs to the maximum length
+            if max_new_tokens is not None:
+                if context_size + max_new_tokens > self.max_length:
+                    logger.warning(
+                        f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
+                    )
+                    context_size = self.max_length - max_new_tokens
+                    if context_size < 0:
+                        logger.critical(
+                            f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
+                        )
+                        raise ValueError("Context size is less than 0.")
+                inputs = [input[-context_size:] for input in inputs]
+            else:
+                if context_size > self.max_length:
+                    logger.warning(
+                        f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
+                    )
+                    context_size = self.max_length
+                    inputs = [input[-context_size:] for input in inputs]
+
+            _outputs = self._generate(
+                inputs=inputs,
+                max_new_tokens=max_new_tokens,
+                stop_tokens=stop_tokens,
+                returns_logits=returns_logits,
+                num_samples=num_samples,
+                continuous_batching=True,
+            )
+
+            for req_id, _output in _outputs.items():
+                output_token_ids = []
+                logprobs_raw = []
+                result = []
+
+                # for output in _output.outputs:
+                output_token_ids.append(_output.generated_tokens)
+                # logprobs_raw.append(output.logprobs)
+                result.append(self.tokenizer.decode(_output.generated_tokens))
+
+                if logprobs_raw and output_token_ids and False:  # logprob extraction is disabled for now
+                    logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
+                else:
+                    logprobs = []
+
+                input_token_ids = _output.prompt_ids
+                cur_response = ModelResponse(
+                    text=result,
+                    logprobs=logprobs,
+                    output_tokens=output_token_ids,
+                    input_tokens=input_token_ids,
+                )
+                results.append(cur_response)
+
+        return dataset.get_original_order(results)
+
+    def _padded_greedy_until(
         self,
         docs: list[Doc],
     ) -> list[ModelResponse]:
@@ -613,12 +726,43 @@ def greedy_until(
                 stop_tokens=stop_tokens,
                 returns_logits=False,
                 num_samples=num_samples,
+                continuous_batching=False,
             )
             results.extend(cur_reponses)
 
         return dataset.get_original_order(results)
 
-    def _generate(
+    def greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        if self.continuous_batching:
+            return self._continuous_greedy_until(docs)
+        else:
+            return self._padded_greedy_until(docs)
+
+    def _generate_continuous(
+        self,
+        inputs: list[list[int]],
+        max_new_tokens: Optional[int] = None,
+        stop_tokens: Optional[list[str]] = None,
+        returns_logits: Optional[bool] = False,
+        num_samples: int = 1,
+        generate: bool = True,
+    ) -> Dict[str, ModelResponse]:
+        # Compute model generation
+        self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
+        self.model.generation_config.max_batch_tokens = 256  # Cap the number of tokens processed per generation batch
+        # self.model.generation_config.do_sample = False
+        batch_outputs = self.model.generate_batch(
+            inputs=inputs,
+            generation_config=self.model.generation_config,
+            # You can pass request-specific overrides here, e.g., max_new_tokens=100
+        )
+
+        return batch_outputs
+
+    def _generate_padded(
         self,
         batch: Batch,
         max_new_tokens: int,
@@ -704,6 +848,16 @@ def _generate(
 
         return all_responses
 
+    def _generate(
+        self,
+        continuous_batching: bool,
+        **kwargs,
+    ) -> list[ModelResponse]:
+        if continuous_batching:
+            return self._generate_continuous(**kwargs)
+        else:
+            return self._generate_padded(**kwargs)
+
     def loglikelihood(
         self,
         docs: list[Doc],
diff --git a/tests/models/endpoints/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py
index 820a23327..5b3aa7563 100644
--- a/tests/models/endpoints/test_endpoint_model.py
+++ b/tests/models/endpoints/test_endpoint_model.py
@@ -52,6 +52,8 @@ class TestInferenceEndpointModelConfig:
             "add_special_tokens": True,
             "system_prompt": None,
             "generation_parameters": {
+                "num_blocks": None,
+                "block_size": None,
                 "early_stopping": None,
                 "frequency_penalty": None,
                 "length_penalty": None,
diff --git a/tests/models/endpoints/test_tgi_model.py b/tests/models/endpoints/test_tgi_model.py
index 93184d5a4..895871597 100644
--- a/tests/models/endpoints/test_tgi_model.py
+++ b/tests/models/endpoints/test_tgi_model.py
@@ -38,6 +38,8 @@ class TestTGIModelConfig:
             "model_name": None,
             "system_prompt": None,
             "generation_parameters": {
+                "block_size": None,
+                "num_blocks": None,
                 "early_stopping": None,
                 "frequency_penalty": None,
                 "length_penalty": None,
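
Usage illustration (not part of the patch): a minimal sketch of driving the new continuous batching path from Python. The `TransformersModel` entry point and its constructor signature, the checkpoint name, and the concrete parameter values are assumptions for illustration; the config fields themselves come from the diff above.

```python
# Sketch only: enable the continuous batching path added by this diff.
from lighteval.models.model_input import GenerationParameters
from lighteval.models.transformers.transformers_model import (  # class name assumed
    TransformersModel,
    TransformersModelConfig,
)

config = TransformersModelConfig(
    model_name="HuggingFaceTB/SmolLM2-135M-Instruct",  # illustrative checkpoint
    batch_size=1,
    continuous_batching=True,  # new flag introduced by this diff
    generation_parameters=GenerationParameters(
        temperature=0.0,
        max_new_tokens=256,
        num_blocks=4096,  # new transformers-only knobs, forwarded by to_transformers_dict()
        block_size=64,
    ),
)

model = TransformersModel(config=config)
# greedy_until() now dispatches to _continuous_greedy_until(), which calls
# model.generate_batch() instead of the padded generate() path:
# responses = model.greedy_until(docs)
```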