Skip to content

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented May 12, 2025

What does this PR do?

Snippet:

import time

import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


torch.set_float32_matmul_precision("high")

model_id = "meta-llama/Llama-3.2-3b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="paged_attention", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

generation_config = GenerationConfig(
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=False,
    num_blocks=2048,
    block_size=128,
    do_sample=True,
    max_batch_tokens=1024,  # Maximum number of tokens to process in a single batch
)

train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")

# --- Example 1: Simple Version using generate_batch ---
print("--- Running CB Generation Example ---")


def tokenize_function(examples):
    return tokenizer(examples["question"])


tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

start_time_simple = time.time()
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)
end_time_simple = time.time()

for request in batch_outputs:
    input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
    try:
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
    except Exception as e:
        print(f"Decoding failed for request {request}: {e}")
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
    if len(output_text) > 0:
        print("-" * 20)
        print(f"{request} Input:  {input_text}")
        print(f"{request} Output: {output_text}")
    else:
        print("", end="\r\r\r\r")
print("-" * 20)
print("--- Finished CB Generation Example ---\n\n")


print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")



class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
class PreTrainedModel(nn.Module, ModuleUtilsMixin, ContinuousMixin, PushToHubMixin, PeftAdapterMixin):
Copy link
Member

@gante gante May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add it instead as a mixin to GenerationMixin? 👀

rationale: many PreTrainedModel-derived instances can't generate, all GenerationMixin-derived instances can. That way, we spare those models of any additional requirement, present and future. We also protect users from exceptions :D

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay does not sound bad ye[

@ArthurZucker ArthurZucker marked this pull request as ready for review May 22, 2025 14:53
@ArthurZucker ArthurZucker merged commit 211f2b0 into main May 22, 2025
18 of 21 checks passed
@ArthurZucker ArthurZucker deleted the feat/stream_inputs_to_continuous_batch branch May 22, 2025 15:43
@ydshieh ydshieh mentioned this pull request May 23, 2025

from ..generation.utils import RequestStatus

class RequestStatus(Enum):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you duplicate the RequestStatus enum @ArthurZucker ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import issue 🤣

redmoe-moutain pushed a commit to redmoe-moutain/transformers that referenced this pull request Jun 10, 2025
* stash for now

* initial commit

* small updated

* up

* up

* works!

* nits and fixes

* don't loop too much

* finish working example

* update

* fix the small freeblocks issue

* feat: stream inputs to continuous batch

* fix: update attn from `eager` to `sdpa`

* refactor: fmt

* refactor: cleanup unnecessary code

* feat: add `update` fn to `PagedAttentionCache`

* feat: broken optimal block size computation

* fix: debugging invalid cache logic

* fix: attention mask

* refactor: use custom prompts for example

* feat: add streaming output

* fix: prefill split

refactor: add doc strings and unsound/redundant logic
fix: compute optimal blocks logic

* fix: send decoded tokens when `prefilling_split` -> `decoding`

* refactor: move logic to appropriate parent class

* fix: remove truncation as we split prefilling anyways

refactor: early return when we have enough selected requests

* feat: add paged attention forward

* push Ggraoh>

* add paged sdpa

* update

* btter mps defaults

* feat: add progress bar for `generate_batch`

* feat: add opentelemetry metrics (ttft + batch fill %age)

* feat: add tracing

* Add cuda graphs (huggingface#38059)

* draft cudagraphs addition

* nits

* styling

* update

* fix

* kinda draft of what it should look like

* fixes

* lol

* not sure why inf everywhere

* can generate but output is shit

* some fixes

* we should have a single device synch

* broken outputs but it does run

* refactor

* updates

* updates with some fixes

* fix mask causality

* another commit that casts after

* add error

* simplify example

* update

* updates

* revert llama changes

* fix merge conflicts

* fix: tracing and metrics

* my updates

* update script default values

* fix block allocation issue

* fix prefill split attnetion mask

* no bugs

* add paged eager

* fix

* update

* style

* feat: add pytorch traces

* fix

* fix

* refactor: remove pytorch profiler data

* style

* nits

* cleanup

* draft test file

* fix

* fix

* fix paged and graphs

* small renamings

* cleanups and push

* refactor: move tracing and metrics logic to utils

* refactor: trace more blocks of code

* nits

* nits

* update

* to profile or not to profile

* refactor: create new output object

* causal by default

* cleanup but generations are still off for IDK what reason

* simplifications but not running still

* this does work.

* small quality of life updates

* nits

* updaet

* fix the scheduler

* fix warning

* ol

* fully fixed

* nits

* different generation parameters

* nice

* just style

* feat: add cache memory usage

* feat: add kv cache free memory

* feat: add active/waiting count & req latency

* do the sampling

* fix: synchronize CUDA only if available and improve error handling in ContinuousBatchingManager

* fix on mps

* feat: add dashboard & histogram buckets

* perf: improve waiting reqs data structures

* attempt to compile, but we should only do it on mps AFAIK

* feat: decouple scheduling logic

* just a draft

* c;eanup and fixup

* optional

* style

* update

* update

* remove the draft documentation

* fix import as well

* update

* fix the test

* style doomed

---------

Co-authored-by: Luc Georges <[email protected]>
clefourrier added a commit to huggingface/lighteval that referenced this pull request Aug 1, 2025
Add necessary changes to call generate with CB
Linked PR: huggingface/transformers#38085
This works:
```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
from lighteval.models.endpoints.inference_providers_model import (
    InferenceProvidersModelConfig,
)
from lighteval.models.transformers.transformers_model import TransformersModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
PROVIDER = "hf-inference"
BENCHMARKS = "lighteval|gsm8k|0|0"

evaluation_tracker = EvaluationTracker(output_dir="./results")
pipeline_params = PipelineParameters(
    use_chat_template=True, launcher_type=ParallelismManager.NONE, max_samples=None
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
)

# Configure generation parameters
generation_config = GenerationConfig(
    max_new_tokens=10,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
    num_blocks=2048,
    block_size=256,
)
model.generation_config = generation_config
model = TransformersModel.from_model(model)
pipeline = Pipeline(
    model=model,
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    tasks=BENCHMARKS,
)

pipeline.evaluate()
results = pipeline.get_results()["results"]
print(results)
```

---------

Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
NathanHB added a commit to huggingface/lighteval that referenced this pull request Sep 19, 2025
Add necessary changes to call generate with CB
Linked PR: huggingface/transformers#38085
This works:
```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
from lighteval.models.endpoints.inference_providers_model import (
    InferenceProvidersModelConfig,
)
from lighteval.models.transformers.transformers_model import TransformersModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
PROVIDER = "hf-inference"
BENCHMARKS = "lighteval|gsm8k|0|0"

evaluation_tracker = EvaluationTracker(output_dir="./results")
pipeline_params = PipelineParameters(
    use_chat_template=True, launcher_type=ParallelismManager.NONE, max_samples=None
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
)

# Configure generation parameters
generation_config = GenerationConfig(
    max_new_tokens=10,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
    num_blocks=2048,
    block_size=256,
)
model.generation_config = generation_config
model = TransformersModel.from_model(model)
pipeline = Pipeline(
    model=model,
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    tasks=BENCHMARKS,
)

pipeline.evaluate()
results = pipeline.get_results()["results"]
print(results)
```

---------

Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants