Conversation

@gante (Contributor) commented Jan 20, 2025

What does this PR do?

✨ Vectorizes beam search ✨

Fixes #34843
Fixes #35618
Fixes #34574
Groundwork for: #35561 #35451 #35436 #32097 #30810
Other related issues: #35023

⚠️ Potential friction points:

  1. This refactor removes the use of the BeamScorer inside _beam_search(). Advanced users relying on their custom BeamScorer instances should pin transformers to <=4.48 for now. Custom beam scoring metrics will be added in a future PR;
  2. The low_memory=True path (which runs the forward pass for each beam sequentially rather than in parallel) was not ported in the refactor. Since this flag was added, several memory-reduction techniques have been introduced, so I'm assuming it's a niche option by now. Supporting it would also bloat the code significantly, so an informative error message was added instead.

Note to reviewers

Don't look at the diff for beam search; rather, read the new _beam_search as if it were a new function. One of the goals of the refactor is improving code readability for future hacking (beam search is a form of test-time compute), and thus the new code is heavily commented -- please add a review comment if a certain part of the code isn't immediately understandable, so we can improve it.

Benefits of the vectorization:

  1. 🚤 Much better scalability with respect to num_beams (see benchmarks below). Note that we can further improve these numbers -- beam_search should now be compatible with torch.compile (to be explored in a future PR);
  2. 🔍 Simpler to read and understand -- it's still quite complex, but we no longer have custom state variables or a large number of nested calls;
  3. 🐛 Bugfix: Beam search now accepts arbitrary stopping criteria (a usage sketch follows right after this list).
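A usage sketch of benefit 3 (not part of this PR's diff): StopOnTokenId below is a hypothetical criterion written against the public StoppingCriteria API, and the model id is simply the one used in the benchmark later in this description.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList


class StopOnTokenId(StoppingCriteria):
    """Flags a sequence as finished once a chosen token id appears anywhere in it."""

    def __init__(self, stop_token_id: int):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # one boolean per row of `input_ids` (rows are batch x beams during beam search)
        return (input_ids == self.stop_token_id).any(dim=-1)


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
inputs = tokenizer(["The quick brown"], return_tensors="pt")

out = model.generate(
    **inputs,
    num_beams=4,
    max_new_tokens=32,
    stopping_criteria=StoppingCriteriaList([StopOnTokenId(tokenizer.eos_token_id)]),
)
print(tokenizer.decode(out[0], skip_special_tokens=True))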

Benchmarks

Test script (small model, beam search, no compilation)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

DEVICE = "cuda"
MODEL_IDS = [
  "Qwen/Qwen2.5-0.5B-Instruct",
]
NUM_BEAMS = [2, 8, 32, 128, 512]
NUM_RUNS = 5
NUM_WARMUP_RUNS = 2
MAX_NEW_TOKENS = 128

generation_kwargs = {
  "do_sample": False,
  "max_new_tokens": MAX_NEW_TOKENS,
  "min_new_tokens": MAX_NEW_TOKENS,  # forces the generation of `max_new_tokens`
}
memory = {}  # Max memory, MB
throughput = {}  # Throughput, tokens/sec

for model_id in MODEL_IDS:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

    inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

    memory[model_id] = {}
    throughput[model_id] = {}

    for num_beams in NUM_BEAMS:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats(DEVICE)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # warmup
        for _ in range(NUM_WARMUP_RUNS):
            _ = model.generate(**inputs, **generation_kwargs, num_beams=num_beams)

        # measure
        start_event.record()
        for _ in range(NUM_RUNS):
            _ = model.generate(**inputs, **generation_kwargs, num_beams=num_beams)
        end_event.record()

        torch.cuda.synchronize()
        max_memory = torch.cuda.max_memory_allocated(DEVICE) * 1e-6
        tok_per_sec = (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3)
        memory[model_id][num_beams] = max_memory
        throughput[model_id][num_beams] = tok_per_sec

        print("\nNum Beams: ", num_beams)
        print("Max memory (MB): ", max_memory)
        print("Throughput (tokens/sec): ", tok_per_sec)

    print("\nModel: ", model_id)
    print("Num Beams: ", NUM_BEAMS)
    print("Max memory (MB): ", memory[model_id])
    print("Throughput (tokens/sec): ", throughput[model_id])

Measured on an RTX 4090

Throughput (tokens/sec):

num_beams | old   | new   | ratio (new/old, higher throughput is better)
2         | 62.98 | 62.82 | 1.00
8         | 60.10 | 62.67 | 1.04
32        | 54.25 | 61.24 | 1.13
128       | 38.15 | 56.51 | 1.48
512       | 15.95 | 36.65 | 2.30

Memory (MB):

num_beams | old     | new     | ratio (new/old, lower memory is better)
2         | 1017.32 | 1016.15 | 1.00
8         | 1050.93 | 1045.92 | 1.00
32        | 1206.99 | 1188.17 | 0.98
128       | 1717.07 | 1641.76 | 0.96
512       | 3848.44 | 3547.72 | 0.92

Tests run

checked box = no regressions vs main

  • RUN_SLOW=1 py.test tests/models/gpt2/test_modeling_gpt2.py -vv -- slow gpt2 tests (reference model for older implementations)
  • RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv -- slow llama tests (key model)
  • RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv -- slow whisper tests (key model, has hard beam search tests) [note: 13 failures on main]
  • RUN_SLOW=1 py.test tests/models/t5/test_modeling_t5.py -vv -- slow t5 tests (has hard beam search tests)
  • RUN_SLOW=1 py.test tests/models/bart/test_modeling_bart.py -vv -- slow bart tests (has hard beam search tests)
  • RUN_SLOW=1 py.test tests/models/rag/test_modeling_rag.py -vv -- slow rag tests (has hard beam search tests) [note: 4 failures on main]
  • RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv -- slow cache tests (sanity check)

Follow-up

  • Deprecate BeamScorer
  • Deprecate other niche beam methods (e.g. constrained_beam_search) -- maybe we can expose the scoring method and write the other beam methods as a scoring modification? (a rough sketch of this idea follows this list)
  • Remove documentation from deprecated methods and classes
  • Explore torch.compile (probably requires isolating the prefill step first)
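A very rough sketch of the scoring-modification idea mentioned above (purely illustrative, not this PR's design; every name in it is hypothetical): beam variants could be expressed as a pluggable function that rewrites the candidate scores before the top-k selection.

from typing import Callable

import torch

ScoreFn = Callable[[torch.Tensor], torch.Tensor]


def identity_scores(log_probs: torch.Tensor) -> torch.Tensor:
    """Vanilla beam search: keep the model's log-probabilities untouched."""
    return log_probs


def make_biased_scores(required_token_id: int, boost: float = 5.0) -> ScoreFn:
    """A crude 'constrained-like' variant: bias the scores towards a required token."""
    def score_fn(log_probs: torch.Tensor) -> torch.Tensor:
        biased = log_probs.clone()
        biased[..., required_token_id] += boost
        return biased
    return score_fn


def beam_step(log_probs: torch.Tensor, num_beams: int, score_fn: ScoreFn = identity_scores):
    """One simplified beam step: modify the scores, then keep the best candidates."""
    return torch.topk(score_fn(log_probs), k=2 * num_beams, dim=-1)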

@gante mentioned this pull request Jan 29, 2025
@gante marked this pull request as ready for review January 30, 2025 10:01
@gante changed the title from "[generate] WIP vectorized beam search" to "[generate] ✨ vectorized beam search ✨" on Jan 30, 2025
@zucchini-nlp (Member) left a comment

Super cool 🔥 Thanks for providing benchmarks with the vectorized beams

Very nice idea to add comments; it is quite hard to follow the logic sometimes. I believe beam search will mostly be tweaked by advanced users, anyway. I still got lost trying to read the code hehe

Comment on lines 3592 to 3599
top_num_beam_mask = torch.cat(
    (
        torch.ones((num_beams), dtype=torch.bool),
        torch.zeros((beams_to_keep - num_beams), dtype=torch.bool),
    ),
    dim=0,
).to(next_token_hits_stopping_criteria.device)
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
Member

This part got me lost a bit. Aren't we choosing beam candidates that have already stopped with eos, but happened to be at indices > num_beams? Since the mask zeroes out the portion after num_beams.

Contributor (PR author)

did_top_num_beams_just_finished selects beams that are valid to be considered as finished, i.e. beams that are in the top num_beams and that have hit some stopping criterion.

Beams in the latter part of the num_beams dimension are indeed discarded by this mask. But this mask is used to tag "finished beams", not the beams we will use in the next iteration of beam search. Finished beams can only come from the top num_beams beams.

Does this make sense? How can I make this part clearer in the comments?
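A toy numeric illustration of the explanation above (not part of the PR; the sizes are made up, and only the two tensor names match the snippet under discussion):

import torch

num_beams, beams_to_keep = 2, 4  # keep 2 beams per batch item, out of 4 candidates
top_num_beam_mask = torch.cat(
    (torch.ones(num_beams, dtype=torch.bool), torch.zeros(beams_to_keep - num_beams, dtype=torch.bool)),
    dim=0,
)  # tensor([True, True, False, False])

# Say that, for one batch item, the candidates at indices 1 and 3 hit a stopping criterion:
next_token_hits_stopping_criteria = torch.tensor([[False, True, False, True]])

# Candidate 1 is inside the top `num_beams`, so it can be tagged as a finished beam;
# candidate 3 also hit the criterion, but it sits outside the top `num_beams` and is discarded.
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
print(did_top_num_beams_just_finished)  # tensor([[False,  True, False, False]])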

@zucchini-nlp (Member) left a comment

For me, the new logic with separate functions made things a bit clearer, in the sense that each small function has a docstring explaining what we do. I think the docstrings are more verbose than the comments we had previously. But it is a bit annoying to keep going back and forth when reading.

If I had to choose, I would go with the previous version, with a note in mind for refactoring. AFAIK most parts of generate will get separated out into functions when we start refactoring. Decoding methods share a lot, and imo we can decompose them.

That way people can also rewrite and plug in their own functions. As it is now, it doesn't seem easy to rewrite a small function (say, for example, how sampling is done in get_top_k_continuations) and plug it into beam search without having to copy the whole decoding method. WDYT? Do we even need to offer users the option to override and plug in small functions?

@ArthurZucker (Collaborator) left a comment

  • let's move the functions outside the scope of the functions themselves
  • let's keep what you did!
    The main arguments:
  1. re-usability: users can easily change a small part of a function (a toy illustration of this pattern follows after this comment)
  2. more readable
  3. easier to debug and isolate changes to a single function!

🤗
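
A toy illustration of the re-usability argument (not the PR's actual code; _select_continuations, _toy_beam_step, and _sampled_continuations are made-up stand-ins for helpers like get_top_k_continuations): once a helper lives at module scope, a user can swap out that one piece without copying the whole decoding loop.

import torch


def _select_continuations(log_probs: torch.Tensor, beams_to_keep: int):
    """Default behaviour: keep the `beams_to_keep` highest-scoring continuations."""
    return torch.topk(log_probs, k=beams_to_keep, dim=-1)


def _toy_beam_step(log_probs: torch.Tensor, beams_to_keep: int, select_fn=_select_continuations):
    """One simplified 'beam step' that delegates candidate selection to a swappable helper."""
    topk_scores, topk_indices = select_fn(log_probs, beams_to_keep)
    return topk_scores, topk_indices


def _sampled_continuations(log_probs: torch.Tensor, beams_to_keep: int):
    """A user-provided alternative: sample candidates instead of taking the top-k."""
    indices = torch.multinomial(log_probs.softmax(dim=-1), num_samples=beams_to_keep)
    return torch.gather(log_probs, -1, indices), indices


# Only the selection helper changes; the rest of the (toy) decoding step is reused as-is.
log_probs = torch.randn(1, 32).log_softmax(dim=-1)
print(_toy_beam_step(log_probs, beams_to_keep=4))
print(_toy_beam_step(log_probs, beams_to_keep=4, select_fn=_sampled_continuations))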

@ArthurZucker (Collaborator) left a comment

oh and

merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)

is indeed better (more readable and less slicing).
As you want; for a 1% perf difference it's not super important.

@gante (Contributor, PR author) commented Feb 14, 2025

@zucchini-nlp @ArthurZucker PR comments addressed (namely, the helper functions were moved out of _beam_search and one set of torch.cat was re-introduced) 🤗

@zucchini-nlp (Member) left a comment

Thanks for iterating!
