@@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):

     def __init__(self, **kwargs):
         # Parameters that control the length of the output
-        # if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
         self.max_length = kwargs.pop("max_length", 20)
         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
         self.min_length = kwargs.pop("min_length", 0)
@@ -407,32 +406,34 @@ def validate(self, is_init=False):
                 "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
                 + fix_location
             )
-            if self.temperature != 1.0:
+            if self.temperature is not None and self.temperature != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
                     UserWarning,
                 )
-            if self.top_p != 1.0:
+            if self.top_p is not None and self.top_p != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
                     UserWarning,
                 )
-            if self.typical_p != 1.0:
+            if self.typical_p is not None and self.typical_p != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
                     UserWarning,
                 )
-            if self.top_k != 50 and self.penalty_alpha is None:  # contrastive search uses top_k
+            if (
+                self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
+            ):  # contrastive search uses top_k
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
                     UserWarning,
                 )
-            if self.epsilon_cutoff != 0.0:
+            if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
                     UserWarning,
                 )
-            if self.eta_cutoff != 0.0:
+            if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
                     UserWarning,
@@ -453,21 +454,21 @@ def validate(self, is_init=False):
                     single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
                     UserWarning,
                 )
-            if self.num_beam_groups != 1:
+            if self.num_beam_groups is not None and self.num_beam_groups != 1:
                 warnings.warn(
                     single_beam_wrong_parameter_msg.format(
                         flag_name="num_beam_groups", flag_value=self.num_beam_groups
                     ),
                     UserWarning,
                 )
-            if self.diversity_penalty != 0.0:
+            if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
                 warnings.warn(
                     single_beam_wrong_parameter_msg.format(
                         flag_name="diversity_penalty", flag_value=self.diversity_penalty
                     ),
                     UserWarning,
                 )
-            if self.length_penalty != 1.0:
+            if self.length_penalty is not None and self.length_penalty != 1.0:
                 warnings.warn(
                     single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
                     UserWarning,
@@ -491,7 +492,7 @@ def validate(self, is_init=False):
                 raise ValueError(
                     constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
                 )
-            if self.num_beam_groups != 1:
+            if self.num_beam_groups is not None and self.num_beam_groups != 1:
                 raise ValueError(
                     constrained_wrong_parameter_msg.format(
                         flag_name="num_beam_groups", flag_value=self.num_beam_groups
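
Note on the `is not None` guards added in the hunks above: these config attributes can legitimately be `None` (for example when a user explicitly unsets a flag), and in Python `None != 1.0` evaluates to `True`, so the old comparisons warned about flags that were never set. A minimal, hypothetical sketch of the pattern outside the class (the helper `warn_if_ignored` and its message are illustrative, not part of the `transformers` API):

import warnings

def warn_if_ignored(flag_name, flag_value, default):
    # The old check was just `flag_value != default`; since `None != 1.0` is True,
    # an unset flag produced a spurious warning. Guarding with `is not None`
    # only warns when the flag was actually set to a non-default value.
    if flag_value is not None and flag_value != default:
        warnings.warn(
            f"`{flag_name}` is set to `{flag_value}` but is only used in "
            "sample-based generation modes.",
            UserWarning,
        )

warn_if_ignored("temperature", None, 1.0)  # unset flag: no warning
warn_if_ignored("temperature", 0.7, 1.0)   # set but ignored flag: warns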
@@ -1000,6 +1001,9 @@ def update(self, **kwargs):
                 setattr(self, key, value)
                 to_remove.append(key)

-        # remove all the attributes that were updated, without modifying the input dict
+        # Confirm that the updated instance is still valid
+        self.validate()
+
+        # Remove all the attributes that were updated, without modifying the input dict
         unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
         return unused_kwargs
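
With `validate()` now called at the end of `update()`, an inconsistent combination passed through `update()` surfaces immediately rather than later at generation time. A hedged usage sketch, assuming a `transformers` version that includes this change (`not_a_real_flag` is just an illustrative unknown key):

from transformers import GenerationConfig

config = GenerationConfig(do_sample=False)
# update() sets the known attribute, then re-runs validate(), which emits a
# UserWarning that `temperature` is only used in sample-based generation modes.
unused = config.update(temperature=0.7, not_a_real_flag=1)
print(unused)  # {'not_a_real_flag': 1}: unknown keys are returned, not applied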