Adds continuous batching #850
Merged
Changes from all commits (17 commits)
41838c0  update for CB (ArthurZucker)
f7a3c2f  update (ArthurZucker)
c9b3467  push (ArthurZucker)
796ef5a  Merge branch 'main' into add-fast-generate (clefourrier)
a7e2751  c'est une honte, 0.2.... ruff.... (ArthurZucker)
a1c4c00  Merge branch 'add-fast-generate' of github.com:ArthurZucker/lighteval… (ArthurZucker)
2b162f7  Merge branch 'main' into add-fast-generate (clefourrier)
101083e  Merge branch 'main' into add-fast-generate (NathanHB)
0f772b1  merge main (NathanHB)
df98d9b  fix model (NathanHB)
1da56bd  fix model (NathanHB)
fe6f24c  fix tests (NathanHB)
96466e4  fix slow tests (NathanHB)
8344961  fix slow tests (NathanHB)
7453c6f  reset vllm model file config (NathanHB)
84237ad  Merge branch 'main' into nathan-add-continious-batching (clefourrier)
c179876  Apply suggestions from code review (clefourrier)
@@ -23,7 +23,7 @@
 import logging
 import os
 from datetime import timedelta
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -41,6 +41,7 @@
     BitsAndBytesConfig,
     PretrainedConfig,
 )
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerateOutput
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

@@ -108,6 +109,8 @@ class TransformersModelConfig(ModelConfig):
             True forces adding space, False removes leading space if present.
         pairwise_tokenization (bool):
             Whether to tokenize context and continuation separately or together. Defaults to False.
+        continuous_batching (bool):
+            Whether to use continuous batching for generation. Defaults to False.

     Example:
         ```python
@@ -143,6 +146,7 @@ class TransformersModelConfig(ModelConfig):
     compile: bool = False
     multichoice_continuations_start_space: bool | None = None
     pairwise_tokenization: bool = False
+    continuous_batching: bool = False

     def model_post_init(self, __context):
         if self.multichoice_continuations_start_space is True:
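For orientation, here is a minimal sketch of how the new flag might be enabled from user code. It assumes the existing `TransformersModelConfig`/`TransformersModel` API and import path, and the model name is a placeholder, not something taken from this PR.

```python
from lighteval.models.transformers.transformers_model import (
    TransformersModel,
    TransformersModelConfig,
)

# Hypothetical usage sketch: the model name below is a placeholder.
config = TransformersModelConfig(
    model_name="HuggingFaceTB/SmolLM2-135M-Instruct",
    continuous_batching=True,  # new field introduced by this PR; defaults to False
)
model = TransformersModel(config=config)  # greedy_until() then routes to the continuous path
```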
@@ -185,7 +189,9 @@ def __init__(
         self._add_special_tokens = config.add_special_tokens or False
         self.pairwise_tokenization = config.pairwise_tokenization
         self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
         self.transformers_config = config.get_transformers_config()
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()

         self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
@@ -206,8 +212,6 @@ def __init__(

         self.model_name = _simplify_name(config.model_name)

-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
-
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
             model_size = convert_bytes(model_size)
@@ -252,14 +256,15 @@ def from_model(

         # Instantiate the object without using __init__
         self = cls.__new__(cls)
-        self.config = config
         self.transformers_config = model.config
-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
+        if config is not None:
+            self.generation_config_dict = config.generation_parameters.to_transformers_dict()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = config.batch_size
+        self.batch_size = getattr(config, "batch_size", None)
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = config.get_model_sha()
+        self.model_sha = self.config.get_model_sha()

         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
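A hedged illustration of what the relaxed `from_model` path above allows: wrapping an already-loaded transformers model without passing a config, in which case a default `TransformersModelConfig` is built from `model.config.name_or_path` and `batch_size` falls back to `None`. The exact `from_model` signature (and any extra arguments it may require) is assumed here, and the model name is a placeholder.

```python
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import TransformersModel

# Placeholder model name; any causal LM should behave the same way here.
hf_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

# Assumed call: with this change, omitting the config builds a default
# TransformersModelConfig from hf_model.config.name_or_path.
lighteval_model = TransformersModel.from_model(hf_model)
```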
@@ -398,6 +403,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         # model.to(self.device)
         model.eval()
         torch.set_grad_enabled(False)
+        if self.continuous_batching:
+            generation_config = GenerationConfig(
+                **self.generation_config_dict,
+            )
+            model.generation_config = generation_config

         if self.config.compile:
             try:
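As context for the `GenerationConfig(**self.generation_config_dict)` call above, a small standalone illustration of building a transformers `GenerationConfig` from a plain dict. The keys and values here are invented; the real dict comes from `generation_parameters.to_transformers_dict()`.

```python
from transformers.generation.configuration_utils import GenerationConfig

# Invented values; the real dict is produced by generation_parameters.to_transformers_dict().
generation_config_dict = {"max_new_tokens": 256, "do_sample": False, "temperature": 1.0}

generation_config = GenerationConfig(**generation_config_dict)
print(generation_config.max_new_tokens)  # 256
```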
@@ -500,7 +510,110 @@ def forward_batch(batch_size):
         logger.info(f"Determined largest batch size: {batch_size}")
         return batch_size

-    def greedy_until(
+    def _continuous_greedy_until(

    [Review comment on this line] Is there any way to factorize more between the continuous and padded greedy_until? (Otherwise, there's a risk we end up having different input management, for example, like we had in the past across generation models.)

+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            docs (list[Doc]): list of documents containing the context and ending conditions.
+
+        Returns:
+            list[ModelResponse]: list of generated responses.
+        """
+        dataset = GenerativeTaskDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
+        results = []
+
+        for split in tqdm(
+            dataset.splits_iterator(),
+            total=dataset.num_dataset_splits,
+            desc="Splits",
+            position=0,
+            disable=False,  # self.disable_tqdm,
+        ):
+            # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
+            if self.use_chat_template:
+                stop_tokens = []
+            else:
+                # NOTE: we are assuming all items in a batch behave similarly (same
+                # stop_tokens and max_tokens generated) which is not necessarily
+                # the case! Because of that we only use batch size of 1
+                stop_tokens = split[0].stop_sequence
+
+            max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
+            returns_logits = split[0].use_logits
+            num_samples = split[0].num_samples
+            contexts = [self.prompt_manager.prepare_prompt(doc) for doc in split]
+            tokenized = self.tokenizer(contexts, add_special_tokens=self.add_special_tokens)
+
+            # The main question for this step is the following:
+            # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+            # of losing some meaning, or have some generations that are exceedingly short?
+            # The choice we go for here is to avoid truncating the prompt if we can, since it
+            # should have been managed by the prompt creator/few shot manager if requested by the user.
+            inputs = tokenized["input_ids"]
+            context_size = len(inputs[0])
+
+            # left truncate the inputs to the maximum length
+            if max_new_tokens is not None:
+                if context_size + max_new_tokens > self.max_length:
+                    logger.warning(
+                        f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
+                    )
+                    context_size = self.max_length - max_new_tokens
+                    if context_size < 0:
+                        logger.critical(
+                            f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
+                        )
+                        raise ValueError("Context size is less than 0.")
+                    inputs = [input[-context_size:] for input in inputs]
+            else:
+                if context_size > self.max_length:
+                    logger.warning(
+                        f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
+                    )
+                    context_size = self.max_length
+                    inputs = [input[-context_size:] for input in inputs]
+
+            _outputs = self._generate(
+                inputs=inputs,
+                max_new_tokens=max_new_tokens,
+                stop_tokens=stop_tokens,
+                returns_logits=returns_logits,
+                num_samples=num_samples,
+                continuous_batching=True,
+            )
+
+            for req_id, _output in _outputs.items():
+                output_token_ids = []
+                logprobs_raw = []
+                result = []
+
+                # for output in _output.outputs:
+                output_token_ids.append(_output.generated_tokens)
+                # logprobs_raw.append(output.logprobs)
+                result.append(self.tokenizer.decode(_output.generated_tokens))
+
+                if logprobs_raw and output_token_ids and False:
+                    logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
+                else:
+                    logprobs = []
+
+                input_token_ids = _output.prompt_ids
+                cur_response = ModelResponse(
+                    text=result,
+                    logprobs=logprobs,
+                    output_tokens=output_token_ids,
+                    input_tokens=input_token_ids,
+                )
+                results.append(cur_response)
+
+        return dataset.get_original_order(results)
+
+    def _padded_greedy_until(
         self,
         docs: list[Doc],
     ) -> list[ModelResponse]:
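To make the left-truncation rule in `_continuous_greedy_until` concrete, here is a self-contained sketch of the same arithmetic with invented numbers; it is not part of the diff.

```python
# Invented numbers, mirroring the truncation logic above.
max_length = 4096              # model context window
max_new_tokens = 512           # budget reserved for generation
input_ids = list(range(4000))  # stand-in for a 4000-token prompt

if len(input_ids) + max_new_tokens > max_length:
    context_size = max_length - max_new_tokens  # 4096 - 512 = 3584
    if context_size < 0:
        raise ValueError("Context size is less than 0.")
    input_ids = input_ids[-context_size:]  # left-truncate: keep the most recent tokens

print(len(input_ids))  # 3584
```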
@@ -613,12 +726,43 @@ def greedy_until(
                 stop_tokens=stop_tokens,
                 returns_logits=False,
                 num_samples=num_samples,
+                continuous_batching=False,
             )
             results.extend(cur_reponses)

         return dataset.get_original_order(results)

-    def _generate(
+    def greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        if self.continuous_batching:
+            return self._continuous_greedy_until(docs)
+        else:
+            return self._padded_greedy_until(docs)
+
+    def _generate_continuous(
+        self,
+        inputs: list[list[int]],
+        max_new_tokens: Optional[int] = None,
+        stop_tokens: Optional[list[str]] = None,
+        returns_logits: Optional[bool] = False,
+        num_samples: int = 1,
+        generate: bool = True,
+    ) -> Dict[str, ModelResponse]:
+        # Compute model generation
+        self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
+        self.model.generation_config.max_batch_tokens = 256  # Cap the number of tokens scheduled per batch
+        # self.model.generation_config.do_sample = False
+        batch_outputs = self.model.generate_batch(
+            inputs=inputs,
+            generation_config=self.model.generation_config,
+            # You can pass request-specific overrides here, e.g., max_new_tokens=100
+        )
+
+        return batch_outputs
+
+    def _generate_padded(
         self,
         batch: Batch,
         max_new_tokens: int,
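For readers who have not used it, `_generate_continuous` leans on transformers' `generate_batch` entry point. The sketch below shows roughly how that call behaves in isolation, based only on how this diff uses it (lists of token ids in, a dict of per-request outputs with `generated_tokens` out). It assumes a recent transformers release that ships `generate_batch`, and the model name is a placeholder.

```python
# Hedged sketch, not lighteval code: exercises generate_batch the way this PR does.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

prompts = ["The capital of France is", "2 + 2 ="]
inputs = tokenizer(prompts)["input_ids"]  # list[list[int]]; no padding required

model.generation_config.max_new_tokens = 32
batch_outputs = model.generate_batch(
    inputs=inputs,
    generation_config=model.generation_config,
)

# Outputs are keyed by request id, matching how _continuous_greedy_until iterates over them.
for req_id, output in batch_outputs.items():
    print(req_id, tokenizer.decode(output.generated_tokens))
```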
@@ -704,6 +848,16 @@ def _generate(

         return all_responses

+    def _generate(
+        self,
+        continuous_batching: bool,
+        **kwargs,
+    ) -> list[ModelResponse]:
+        if continuous_batching:
+            return self._generate_continuous(**kwargs)
+        else:
+            return self._generate_padded(**kwargs)
+
     def loglikelihood(
         self,
         docs: list[Doc],