Skip to content

Commit a7755d2

Browse files
authored
Generate: unset GenerationConfig parameters do not raise warning (#29119)
1 parent 7d312ad commit a7755d2

File tree

6 files changed

+42
-24
lines changed

6 files changed

+42
-24
lines changed

src/transformers/generation/configuration_utils.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):
271271

272272
def __init__(self, **kwargs):
273273
# Parameters that control the length of the output
274-
# if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
275274
self.max_length = kwargs.pop("max_length", 20)
276275
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
277276
self.min_length = kwargs.pop("min_length", 0)
@@ -407,32 +406,34 @@ def validate(self, is_init=False):
407406
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
408407
+ fix_location
409408
)
410-
if self.temperature != 1.0:
409+
if self.temperature is not None and self.temperature != 1.0:
411410
warnings.warn(
412411
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
413412
UserWarning,
414413
)
415-
if self.top_p != 1.0:
414+
if self.top_p is not None and self.top_p != 1.0:
416415
warnings.warn(
417416
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
418417
UserWarning,
419418
)
420-
if self.typical_p != 1.0:
419+
if self.typical_p is not None and self.typical_p != 1.0:
421420
warnings.warn(
422421
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
423422
UserWarning,
424423
)
425-
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
424+
if (
425+
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
426+
): # contrastive search uses top_k
426427
warnings.warn(
427428
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
428429
UserWarning,
429430
)
430-
if self.epsilon_cutoff != 0.0:
431+
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
431432
warnings.warn(
432433
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
433434
UserWarning,
434435
)
435-
if self.eta_cutoff != 0.0:
436+
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
436437
warnings.warn(
437438
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
438439
UserWarning,
@@ -453,21 +454,21 @@ def validate(self, is_init=False):
453454
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
454455
UserWarning,
455456
)
456-
if self.num_beam_groups != 1:
457+
if self.num_beam_groups is not None and self.num_beam_groups != 1:
457458
warnings.warn(
458459
single_beam_wrong_parameter_msg.format(
459460
flag_name="num_beam_groups", flag_value=self.num_beam_groups
460461
),
461462
UserWarning,
462463
)
463-
if self.diversity_penalty != 0.0:
464+
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
464465
warnings.warn(
465466
single_beam_wrong_parameter_msg.format(
466467
flag_name="diversity_penalty", flag_value=self.diversity_penalty
467468
),
468469
UserWarning,
469470
)
470-
if self.length_penalty != 1.0:
471+
if self.length_penalty is not None and self.length_penalty != 1.0:
471472
warnings.warn(
472473
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
473474
UserWarning,
@@ -491,7 +492,7 @@ def validate(self, is_init=False):
491492
raise ValueError(
492493
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
493494
)
494-
if self.num_beam_groups != 1:
495+
if self.num_beam_groups is not None and self.num_beam_groups != 1:
495496
raise ValueError(
496497
constrained_wrong_parameter_msg.format(
497498
flag_name="num_beam_groups", flag_value=self.num_beam_groups
@@ -1000,6 +1001,9 @@ def update(self, **kwargs):
10001001
setattr(self, key, value)
10011002
to_remove.append(key)
10021003

1003-
# remove all the attributes that were updated, without modifying the input dict
1004+
# Confirm that the updated instance is still valid
1005+
self.validate()
1006+
1007+
# Remove all the attributes that were updated, without modifying the input dict
10041008
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
10051009
return unused_kwargs

src/transformers/generation/flax_utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,6 @@ def generate(
330330

331331
generation_config = copy.deepcopy(generation_config)
332332
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
333-
generation_config.validate()
334333
self._validate_model_kwargs(model_kwargs.copy())
335334

336335
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()

src/transformers/generation/tf_utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,6 @@ def generate(
736736

737737
generation_config = copy.deepcopy(generation_config)
738738
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
739-
generation_config.validate()
740739
self._validate_model_kwargs(model_kwargs.copy())
741740

742741
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)

src/transformers/generation/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -1347,7 +1347,6 @@ def generate(
13471347

13481348
generation_config = copy.deepcopy(generation_config)
13491349
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
1350-
generation_config.validate()
13511350
self._validate_model_kwargs(model_kwargs.copy())
13521351

13531352
# 2. Set generation parameters if not already defined

src/transformers/utils/quantization_config.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def to_json_string(self, use_diff: bool = True) -> str:
152152
config_dict = self.to_dict()
153153
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
154154

155-
# Copied from transformers.generation.configuration_utils.GenerationConfig.update
156155
def update(self, **kwargs):
157156
"""
158157
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
@@ -171,7 +170,7 @@ def update(self, **kwargs):
171170
setattr(self, key, value)
172171
to_remove.append(key)
173172

174-
# remove all the attributes that were updated, without modifying the input dict
173+
# Remove all the attributes that were updated, without modifying the input dict
175174
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
176175
return unused_kwargs
177176

tests/generation/test_configuration_utils.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -124,26 +124,44 @@ def test_validate(self):
124124
"""
125125
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
126126
"""
127-
# Case 1: A correct configuration will not throw any warning
127+
# A correct configuration will not throw any warning
128128
with warnings.catch_warnings(record=True) as captured_warnings:
129129
GenerationConfig()
130130
self.assertEqual(len(captured_warnings), 0)
131131

132-
# Case 2: Inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
132+
# Inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
133133
# parameters with `do_sample=False`). May be escalated to an error in the future.
134134
with warnings.catch_warnings(record=True) as captured_warnings:
135-
GenerationConfig(temperature=0.5)
135+
GenerationConfig(do_sample=False, temperature=0.5)
136136
self.assertEqual(len(captured_warnings), 1)
137137

138-
# Case 3: Impossible sets of constraints/parameters will raise an exception
138+
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
139+
# that is done by unsetting the parameter (i.e. setting it to None)
140+
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
141+
with warnings.catch_warnings(record=True) as captured_warnings:
142+
# BAD - 0.9 means it is still set, we should warn
143+
generation_config_bad_temperature.update(temperature=0.9)
144+
self.assertEqual(len(captured_warnings), 1)
145+
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
146+
with warnings.catch_warnings(record=True) as captured_warnings:
147+
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
148+
generation_config_bad_temperature.update(temperature=1.0)
149+
self.assertEqual(len(captured_warnings), 0)
150+
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
151+
with warnings.catch_warnings(record=True) as captured_warnings:
152+
# OK - None means it is unset, nothing to warn about
153+
generation_config_bad_temperature.update(temperature=None)
154+
self.assertEqual(len(captured_warnings), 0)
155+
156+
# Impossible sets of constraints/parameters will raise an exception
139157
with self.assertRaises(ValueError):
140-
GenerationConfig(num_return_sequences=2)
158+
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
141159

142-
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
160+
# Passing `generate()`-only flags to `validate` will raise an exception
143161
with self.assertRaises(ValueError):
144162
GenerationConfig(logits_processor="foo")
145163

146-
# Case 5: Model-specific parameters will NOT raise an exception or a warning
164+
# Model-specific parameters will NOT raise an exception or a warning
147165
with warnings.catch_warnings(record=True) as captured_warnings:
148166
GenerationConfig(foo="bar")
149167
self.assertEqual(len(captured_warnings), 0)

0 commit comments

Comments
 (0)