Add CB #38085
Conversation
refactor: add doc strings and unsound/redundant logic
fix: compute optimal blocks logic
refactor: early return when we have enough selected requests
…uggingface/transformers into feat/stream_inputs_to_continuous_batch
…uggingface/transformers into feat/stream_inputs_to_continuous_batch
… ContinuousBatchingManager
src/transformers/modeling_utils.py (Outdated)

```diff
-class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, ContinuousMixin, PushToHubMixin, PeftAdapterMixin):
```
can we add it instead as a mixin to GenerationMixin? 👀

rationale: many PreTrainedModel-derived instances can't generate, all GenerationMixin-derived instances can. That way, we spare those models of any additional requirement, present and future. We also protect users from exceptions :D
okay, does not sound bad, yep
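In other words, the proposal hangs the new mixin off GenerationMixin rather than PreTrainedModel, so only generate-capable models pick it up. A minimal sketch of that arrangement (class bodies elided and assumed, not the actual implementation):

```python
class ContinuousMixin:
    """Continuous-batching entry points (e.g. the `generate_batch` path)."""
    ...


class GenerationMixin(ContinuousMixin):
    """Every model that can generate inherits this, and now picks up CB with it."""
    ...


class PreTrainedModel:
    """Left unchanged: not every PreTrainedModel can generate."""
    ...
```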
```python
from ..generation.utils import RequestStatus

class RequestStatus(Enum):
```
Why did you duplicate the RequestStatus enum @ArthurZucker?
Import issue 🤣
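For context, RequestStatus tracks where a request sits in the continuous-batching scheduler. A rough sketch, with member names inferred from the commit log (`prefilling_split` -> `decoding`) rather than copied from the actual enum:

```python
from enum import Enum


class RequestStatus(Enum):
    # Member set assumed; the real enum lives in transformers' generation utils.
    PENDING = "pending"                    # queued, not yet scheduled
    PREFILLING = "prefilling"              # prompt tokens being processed
    PREFILLING_SPLIT = "prefilling_split"  # long prompt split across forward passes
    DECODING = "decoding"                  # producing new tokens
    FINISHED = "finished"                  # hit EOS or max_new_tokens
```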
* stash for now
* initial commit
* small updates
* up
* up
* works!
* nits and fixes
* don't loop too much
* finish working example
* update
* fix the small freeblocks issue
* feat: stream inputs to continuous batch
* fix: update attn from `eager` to `sdpa`
* refactor: fmt
* refactor: cleanup unnecessary code
* feat: add `update` fn to `PagedAttentionCache`
* feat: broken optimal block size computation
* fix: debugging invalid cache logic
* fix: attention mask
* refactor: use custom prompts for example
* feat: add streaming output
* fix: prefill split; refactor: add doc strings and unsound/redundant logic; fix: compute optimal blocks logic
* fix: send decoded tokens when `prefilling_split` -> `decoding`
* refactor: move logic to appropriate parent class
* fix: remove truncation as we split prefilling anyways; refactor: early return when we have enough selected requests
* feat: add paged attention forward
* push graph
* add paged sdpa
* update
* better mps defaults
* feat: add progress bar for `generate_batch`
* feat: add opentelemetry metrics (ttft + batch fill %age)
* feat: add tracing
* Add cuda graphs (huggingface#38059)
* draft cudagraphs addition
* nits
* styling
* update
* fix
* kinda draft of what it should look like
* fixes
* lol
* not sure why inf everywhere
* can generate but output is shit
* some fixes
* we should have a single device synch
* broken outputs but it does run
* refactor
* updates
* updates with some fixes
* fix mask causality
* another commit that casts after
* add error
* simplify example
* update
* updates
* revert llama changes
* fix merge conflicts
* fix: tracing and metrics
* my updates
* update script default values
* fix block allocation issue
* fix prefill split attention mask
* no bugs
* add paged eager
* fix
* update
* style
* feat: add pytorch traces
* fix
* fix
* refactor: remove pytorch profiler data
* style
* nits
* cleanup
* draft test file
* fix
* fix
* fix paged and graphs
* small renamings
* cleanups and push
* refactor: move tracing and metrics logic to utils
* refactor: trace more blocks of code
* nits
* nits
* update
* to profile or not to profile
* refactor: create new output object
* causal by default
* cleanup but generations are still off for IDK what reason
* simplifications but not running still
* this does work.
* small quality of life updates
* nits
* update
* fix the scheduler
* fix warning
* ol
* fully fixed
* nits
* different generation parameters
* nice
* just style
* feat: add cache memory usage
* feat: add kv cache free memory
* feat: add active/waiting count & req latency
* do the sampling
* fix: synchronize CUDA only if available and improve error handling in ContinuousBatchingManager
* fix on mps
* feat: add dashboard & histogram buckets
* perf: improve waiting reqs data structures
* attempt to compile, but we should only do it on mps AFAIK
* feat: decouple scheduling logic
* just a draft
* cleanup and fixup
* optional
* style
* update
* update
* remove the draft documentation
* fix import as well
* update
* fix the test
* style doomed

---------

Co-authored-by: Luc Georges <[email protected]>
Add necessary changes to call generate with CB

Linked PR: huggingface/transformers#38085

This works:

```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
from lighteval.models.endpoints.inference_providers_model import (
    InferenceProvidersModelConfig,
)
from lighteval.models.transformers.transformers_model import TransformersModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
PROVIDER = "hf-inference"
BENCHMARKS = "lighteval|gsm8k|0|0"

evaluation_tracker = EvaluationTracker(output_dir="./results")
pipeline_params = PipelineParameters(
    use_chat_template=True, launcher_type=ParallelismManager.NONE, max_samples=None
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct",
    attn_implementation="sdpa_paged",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Configure generation parameters
generation_config = GenerationConfig(
    max_new_tokens=10,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
    num_blocks=2048,
    block_size=256,
)
model.generation_config = generation_config

model = TransformersModel.from_model(model)

pipeline = Pipeline(
    model=model,
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    tasks=BENCHMARKS,
)

pipeline.evaluate()
results = pipeline.get_results()["results"]
print(results)
```

---------

Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
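As an aside on the snippet's cache parameters: `num_blocks` times `block_size` bounds how many tokens the paged KV cache can hold across all in-flight requests. A back-of-the-envelope sketch, with the per-token footprint assumed from a Llama-3.2-3B-like config (28 layers, 8 KV heads, head dim 128, bf16):

```python
num_blocks, block_size = 2048, 256
max_cached_tokens = num_blocks * block_size  # 524,288 token slots shared by all requests

# Per-token KV bytes per layer: 2 (K and V) * kv_heads * head_dim * dtype bytes.
num_layers, kv_heads, head_dim, bf16_bytes = 28, 8, 128, 2  # assumed model config
bytes_per_token = 2 * kv_heads * head_dim * bf16_bytes * num_layers

print(f"KV cache upper bound: {max_cached_tokens * bytes_per_token / 2**30:.1f} GiB")  # ~56 GiB
```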
What does this PR do?
Snippet:
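A minimal sketch of driving the continuous-batching API directly from transformers (`generate_batch` and `sdpa_paged` appear in the commit log above; the checkpoint, parameter values, and return structure are assumptions):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Checkpoint and paged-attention kwargs mirror the lighteval snippet above.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct",
    attn_implementation="sdpa_paged",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.generation_config = GenerationConfig(
    max_new_tokens=32, num_blocks=1024, block_size=128
)

prompts = ["What is continuous batching?", "List three prime numbers."]
batch = [tokenizer(p).input_ids for p in prompts]  # token-id lists, one per request

# All requests flow through one continuously refilled batch; the return
# structure is assumed here to be one decodable token sequence per request.
for seq in model.generate_batch(inputs=batch):
    print(tokenizer.decode(seq, skip_special_tokens=True))
```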