Fix kwargs handling in generate_with_fallback
#29225
Conversation
ylacombe
left a comment
Hey @cifkao, thanks for the great work here, it's a nice catch.
The fix seems okay to me. I don't think we have a way to test whether it works, otherwise I'd have asked you for that!
@sanchit-gandhi could we have your review here as well?
generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
generate_kwargs = dict(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:
temperature shouldn't be in kwargs as it's already an argument of .generate here, right?
It seems okay to check for do_sample and num_beams here.
Just wanted to be extra cautious here and make sure everything is safe locally, rather than relying on what gets passed down from 2 call frames up the stack. But I can remove temperature if you prefer.
Looks good to me as is - there's a preference for more explicit handling of kwargs over more buried handling.
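For context, here is a tiny illustration of why that explicit stripping matters (ToyGenerationConfig is a made-up stand-in, not the transformers class): keyword arguments passed to .generate() take precedence over the generation config, so any of these keys left in kwargs would silently override what the fallback loop just set.

class ToyGenerationConfig:
    # Made-up stand-in mimicking how explicit generate() kwargs update the config.
    def __init__(self, num_beams=1, do_sample=False):
        self.num_beams = num_beams
        self.do_sample = do_sample

    def update(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)  # explicit kwargs win over existing config values

config = ToyGenerationConfig(num_beams=1)  # value chosen by the fallback loop
config.update(num_beams=5)                 # a num_beams left in kwargs would win here
print(config.num_beams)                    # 5 -- the loop's per-iteration choice is lost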
do_condition_on_prev_tokens,
kwargs,
):
    kwargs = dict(kwargs)
Why do you use dict(...) here and below? Is it to copy? If yes, shouldn't we use copy.deepcopy instead?
Yes, it's just to make a copy. My thinking here was that a shallow copy (using dict() or copy.copy()) has the same effect as using the **kwargs syntax.
copy.deepcopy should do the trick then, right? I'm just afraid that using dict might not be self-explanatory.
I still don't think we want to make a deep copy (what if kwargs contains a large object like assistant_model, for example?). So I changed the dict to copy.copy, which is equivalent and more readable.
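A small self-contained illustration of the copy semantics being discussed (the assistant_model entry is only a placeholder for a large object):

import copy

kwargs = {"num_beams": 2, "assistant_model": object()}  # object() stands in for a big model

shallow = copy.copy(kwargs)            # same effect as dict(kwargs) or forwarding **kwargs
assert shallow is not kwargs           # a new mapping...
assert shallow["assistant_model"] is kwargs["assistant_model"]  # ...but values are shared

shallow.pop("num_beams")               # mutating the copy leaves the caller's dict intact
assert "num_beams" in kwargs
# copy.deepcopy(kwargs) would also duplicate every value, which is unnecessary here and
# wasteful for large entries such as an assistant model.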
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Thanks @cifkao! Sounds good, just want to make sure you have a reproducer!
sanchit-gandhi
left a comment
Thanks for the great issue and super clear PR description @cifkao! The PR looks good to me. My only request is that we add a test to confirm beam search is working as expected. Could we modify your reproducer to do this, possibly with something like the following?
import datasets
from transformers import AutoProcessor, GenerationMixin, WhisperForConditionalGeneration
import numpy as np

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

orig_generate = GenerationMixin.generate
NUM_BEAMS = 2

def generate(self, *args, **kwargs):
    assert args[1].num_beams == NUM_BEAMS
    return orig_generate(self, *args, **kwargs)

GenerationMixin.generate = generate

ds = datasets.load_dataset(
    "google/fleurs", "en_us", split="test", trust_remote_code=True
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16000))
raw_audio = np.concatenate([x["array"].astype(np.float32) for x in ds[:16]["audio"]])

inputs = processor(
    [raw_audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
)

model.generate(
    **inputs,
    num_beams=NUM_BEAMS,
    task="transcribe",
    language="en",
)
@sanchit-gandhi Test added!
ArthurZucker
left a comment
Thanks for iterating and adding a test!
What does this PR do?
Fixes #29312.
- pop() is changed to get() to avoid modifying kwargs between loop iterations.
- A copy of kwargs is made as the first step in generate_with_fallback() to prevent any changes to it from propagating outside the method call.
- Keys that are already set on generation_config are removed from the keyword arguments to super().generate() (to avoid overriding the former), but this is done in a copy of kwargs that is not reused between iterations.
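For illustration, a minimal sketch of the kwargs handling described above, using copy.copy and a SimpleNamespace as a simplified stand-in for the real generate_with_fallback() signature and generation config (not the actual Whisper source):

import copy
from types import SimpleNamespace

def generate_with_fallback(generation_config, temperatures, kwargs):
    # Copy once at the top so per-segment changes never leak back to the caller.
    kwargs = copy.copy(kwargs)
    beams_per_attempt = []
    for temperature in temperatures:
        generation_config.do_sample = temperature is not None and temperature > 0.0
        # get() (not pop()) keeps num_beams available on every fallback iteration.
        generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
        generate_kwargs = copy.copy(kwargs)
        for key in ["do_sample", "temperature", "num_beams"]:
            generate_kwargs.pop(key, None)  # already captured on generation_config
        beams_per_attempt.append(generation_config.num_beams)
    return beams_per_attempt

caller_kwargs = {"num_beams": 2}
print(generate_with_fallback(SimpleNamespace(), (0.0, 0.2), caller_kwargs))  # [2, 1]
print(caller_kwargs)  # {'num_beams': 2} -- unchanged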
Who can review?
@patrickvonplaten @sanchit-gandhi @ylacombe