Adds continuous batching #850
Conversation
… into add-fast-generate
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Pull Request Overview
This PR introduces continuous batching support for Transformer-based models, enabling split-wise streaming generation.
- Adds a `continuous_batching` flag throughout configuration, model initialization, and generation functions.
- Implements a new `_continuous_greedy_until` path and refactors `_generate` to dispatch based on the flag (see the sketch after this list).
- Updates `GenerationParameters` and example configs to include `num_blocks` and `block_size`, and adjusts tests accordingly.
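For orientation, the dispatch in the second bullet might look roughly like the following. This is a minimal sketch, not the PR's actual code: the method names are taken from the diff excerpts further down this page, while the class, constructor, and document handling are assumptions.

```python
# Minimal sketch of the continuous_batching dispatch (not lighteval's real class).
class TransformersModelSketch:
    def __init__(self, continuous_batching: bool = False):
        # Flag propagated from the model config / generation parameters.
        self.continuous_batching = continuous_batching

    def _continuous_greedy_until(self, docs):
        # Paged / continuous-batching generation path (stubbed here).
        return [f"continuous:{doc}" for doc in docs]

    def _padded_greedy_until(self, docs):
        # Classic padded-batch generation path (stubbed here).
        return [f"padded:{doc}" for doc in docs]

    def greedy_until(self, docs):
        # Single entry point that routes on the flag.
        if self.continuous_batching:
            return self._continuous_greedy_until(docs)
        return self._padded_greedy_until(docs)


print(TransformersModelSketch(continuous_batching=True).greedy_until(["What is 2+2?"]))
```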
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
Summary per file:
| File | Description |
|---|---|
| tests/models/endpoints/test_tgi_model.py | Inserts `block_size` and `num_blocks` into generation parameters |
| tests/models/endpoints/test_endpoint_model.py | Inserts `num_blocks` and `block_size` into generation parameters |
| src/lighteval/models/transformers/transformers_model.py | Propagates `continuous_batching` through init, `from_model`, and generate paths |
| src/lighteval/models/model_input.py | Extends `GenerationParameters` with `num_blocks` and `block_size` |
| examples/model_configs/transformers_model.yaml | Adds `continuous_batching` and example `num_blocks`/`block_size` |
Comments suppressed due to low confidence (2)
src/lighteval/models/model_input.py:28
- [nitpick] New fields `num_blocks` and `block_size` in `GenerationParameters` lack descriptions in the class docstring. Consider documenting their purpose and effects.

`num_blocks: NonNegativeInt | None = None  # transformers`
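As an illustration of the requested documentation, the two fields could be described along these lines. This is a sketch only, assuming a pydantic-style model as the `NonNegativeInt` annotation suggests; it is not lighteval's actual `GenerationParameters` class.

```python
# Sketch: documenting the two new continuous-batching fields flagged above.
# Field names come from this PR; everything else is illustrative.
from pydantic import BaseModel, NonNegativeInt


class GenerationParametersSketch(BaseModel):
    """Generation settings forwarded to the transformers backend.

    Attributes:
        num_blocks: Number of KV-cache blocks to allocate for continuous
            (paged) batching; None keeps the backend default.
        block_size: Number of tokens stored per KV-cache block; None keeps
            the backend default.
    """

    num_blocks: NonNegativeInt | None = None  # transformers
    block_size: NonNegativeInt | None = None  # transformers


params = GenerationParametersSketch(num_blocks=2048, block_size=256)
print(params.num_blocks, params.block_size)
```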
src/lighteval/models/transformers/transformers_model.py:114
- There are no existing tests covering the new `continuous_batching` logic path. Consider adding unit tests to verify both `True` and `False` behaviors.

`continuous_batching (bool):`
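A hedged sketch of what such tests might check: each flag value routes to a different code path. It exercises the dispatch pattern in isolation with mocks rather than lighteval's real `TransformersModel`, whose constructor is more involved (assumption).

```python
# Illustrative unit tests for the continuous_batching dispatch, using mocks
# instead of lighteval's real TransformersModel (assumption).
import unittest
from unittest.mock import MagicMock


class TestContinuousBatchingDispatch(unittest.TestCase):
    def _model(self, flag: bool):
        model = MagicMock()
        model.continuous_batching = flag
        # Reproduce the dispatch logic under test on top of the mock.
        model.greedy_until = lambda docs: (
            model._continuous_greedy_until(docs)
            if model.continuous_batching
            else model._padded_greedy_until(docs)
        )
        return model

    def test_true_routes_to_continuous(self):
        model = self._model(True)
        model.greedy_until(["doc"])
        model._continuous_greedy_until.assert_called_once_with(["doc"])
        model._padded_greedy_until.assert_not_called()

    def test_false_routes_to_padded(self):
        model = self._model(False)
        model.greedy_until(["doc"])
        model._padded_greedy_until.assert_called_once_with(["doc"])
        model._continuous_greedy_until.assert_not_called()


if __name__ == "__main__":
    unittest.main()
```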
    else:
        return self._padded_greedy_until(docs)

    def _generate_fast(
Is `generate_fast` for continuous batching only? If yes, call it `generate_continuous` then, since the other is `generate_padded` and not `generate_slow` (for homogeneity).
    return batch_size

    def greedy_until(
    def _continuous_greedy_until(
Is there any way to factorize more between the continuous and padded `greedy_until`? (Otherwise, there's a risk we end up having different input management, for example, like we had in the past across generation models.)
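One possible direction, sketched here purely as an assumption rather than the PR's actual code: factor the shared input preparation into a single helper so the two greedy paths can only diverge in the backend call itself.

```python
# Hypothetical refactor sketch: both generation paths share the same
# input-preparation step, and only the backend call differs.
from dataclasses import dataclass


@dataclass
class PreparedBatch:
    prompts: list[str]
    stop_sequences: list[str]
    max_new_tokens: int


def _prepare_batch(docs, default_max_new_tokens: int = 256) -> PreparedBatch:
    # Single place where prompt formatting, stop sequences, and generation
    # length would be decided for both paths.
    return PreparedBatch(
        prompts=[d["query"] for d in docs],
        stop_sequences=[s for d in docs for s in d.get("stop", [])],
        max_new_tokens=default_max_new_tokens,
    )


def _padded_greedy_until(docs):
    batch = _prepare_batch(docs)
    return [f"padded({p})" for p in batch.prompts]      # placeholder backend call


def _continuous_greedy_until(docs):
    batch = _prepare_batch(docs)
    return [f"continuous({p})" for p in batch.prompts]  # placeholder backend call


docs = [{"query": "What is 2+2?", "stop": ["\n"]}]
print(_padded_greedy_until(docs), _continuous_greedy_until(docs))
```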
Content from original PR:

    from lighteval.logging.evaluation_tracker import EvaluationTracker
    from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
    from lighteval.models.endpoints.inference_providers_model import (
        InferenceProvidersModelConfig,
    )
    from lighteval.models.transformers.transformers_model import TransformersModel
    import torch
    from transformers import AutoModelForCausalLM, GenerationConfig

    MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
    PROVIDER = "hf-inference"
    BENCHMARKS = "lighteval|gsm8k|0|0"

    evaluation_tracker = EvaluationTracker(output_dir="./results")
    pipeline_params = PipelineParameters(
        use_chat_template=True, launcher_type=ParallelismManager.NONE, max_samples=None
    )

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
    )

    # Configure generation parameters
    generation_config = GenerationConfig(
        max_new_tokens=10,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.pad_token_id,
        num_blocks=2048,
        block_size=256,
    )
    model.generation_config = generation_config

    model = TransformersModel.from_model(model)

    pipeline = Pipeline(
        model=model,
        pipeline_parameters=pipeline_params,
        evaluation_tracker=evaluation_tracker,
        tasks=BENCHMARKS,
    )

    pipeline.evaluate()
    results = pipeline.get_results()["results"]
    print(results)
Does not work on my side 😢 I might have done something wrong tho!
I'll debug this week.
Add necessary changes to call generate with CB

Linked PR: huggingface/transformers#38085

This works:

```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
from lighteval.models.endpoints.inference_providers_model import (
    InferenceProvidersModelConfig,
)
from lighteval.models.transformers.transformers_model import TransformersModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
PROVIDER = "hf-inference"
BENCHMARKS = "lighteval|gsm8k|0|0"

evaluation_tracker = EvaluationTracker(output_dir="./results")
pipeline_params = PipelineParameters(
    use_chat_template=True, launcher_type=ParallelismManager.NONE, max_samples=None
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
)

# Configure generation parameters
generation_config = GenerationConfig(
    max_new_tokens=10,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
    num_blocks=2048,
    block_size=256,
)
model.generation_config = generation_config

model = TransformersModel.from_model(model)

pipeline = Pipeline(
    model=model,
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    tasks=BENCHMARKS,
)

pipeline.evaluate()
results = pipeline.get_results()["results"]
print(results)
```

---------

Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>