-
Notifications
You must be signed in to change notification settings - Fork 31.9k
[GenerationConfig] add additional kwargs handling #21269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
260fb3b
2c70587
659e776
757dd0d
95dbc4b
71c7255
2a9ff00
a0a06a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -282,6 +282,16 @@ def __init__(self, **kwargs): | |
| self._commit_hash = kwargs.pop("_commit_hash", None) | ||
| self.transformers_version = kwargs.pop("transformers_version", __version__) | ||
|
|
||
| # Additional attributes without default values | ||
| if not self._from_model_config: | ||
| # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file | ||
| for key, value in kwargs.items(): | ||
| try: | ||
| setattr(self, key, value) | ||
| except AttributeError as err: | ||
| logger.error(f"Can't set {key} with value {value} for {self}") | ||
| raise err | ||
|
|
||
| def __eq__(self, other): | ||
| self_dict = self.__dict__.copy() | ||
| other_dict = other.__dict__.copy() | ||
|
|
@@ -537,7 +547,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": | |
| if "_commit_hash" in kwargs and "_commit_hash" in config_dict: | ||
| kwargs["_commit_hash"] = config_dict["_commit_hash"] | ||
|
|
||
| config = cls(**config_dict) | ||
| # remove all the arguments that are in the config_dict | ||
|
|
||
| config = cls(**config_dict, **kwargs) | ||
| unused_kwargs = config.update(**kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This line now only exists to obtain
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Well yes, for example when |
||
|
|
||
| logger.info(f"Generate config {config}") | ||
|
|
@@ -546,6 +558,18 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": | |
| else: | ||
| return config | ||
|
|
||
| def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: | ||
| """ | ||
| Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, | ||
| converts torch.dtype to a string of just the type. For example, `torch.float32` gets converted into the | ||
| *"float32"* string, which can then be stored in the json format. | ||
| The dictionary `d` is modified in place and nothing is returned. A *torch_dtype* value that is already a | ||
| string is left untouched. | ||
| """ | ||
| # Keep only the second dot-separated component of the value's string form, e.g. "torch.float32" -> "float32". | ||
| if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): | ||
| d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] | ||
| # Recurse into every nested dictionary so their *torch_dtype* entries are converted as well. | ||
| for value in d.values(): | ||
| if isinstance(value, dict): | ||
| self.dict_torch_dtype_to_str(value) | ||
|
|
||
| def to_diff_dict(self) -> Dict[str, Any]: | ||
| """ | ||
| Removes all attributes from config which correspond to the default config attributes for better readability and | ||
|
|
@@ -566,6 +590,7 @@ def to_diff_dict(self) -> Dict[str, Any]: | |
| if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]: | ||
| serializable_config_dict[key] = value | ||
|
|
||
| self.dict_torch_dtype_to_str(serializable_config_dict) | ||
| return serializable_config_dict | ||
|
|
||
| def to_dict(self) -> Dict[str, Any]: | ||
|
|
@@ -582,6 +607,7 @@ def to_dict(self) -> Dict[str, Any]: | |
| # Transformers version when serializing this file | ||
| output["transformers_version"] = __version__ | ||
|
|
||
| self.dict_torch_dtype_to_str(output) | ||
| return output | ||
|
|
||
| def to_json_string(self, use_diff: bool = True) -> str: | ||
|
|
@@ -630,7 +656,8 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" | |
| [`GenerationConfig`]: The configuration object instantiated from those parameters. | ||
| """ | ||
| config_dict = model_config.to_dict() | ||
| config = cls.from_dict(config_dict, return_unused_kwargs=False) | ||
| config_dict.pop("_from_model_config", None) | ||
| config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True) | ||
|
|
||
| # Special case: some models have generation attributes set in the decoder. Use them if still unset in the | ||
| # generation config. | ||
|
|
@@ -642,7 +669,6 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" | |
| if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): | ||
| setattr(config, attr, decoder_config[attr]) | ||
|
|
||
| config._from_model_config = True | ||
| return config | ||
|
|
||
| def update(self, **kwargs): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd add a comment here explaining why we need this `if`; otherwise we may be like "wtf?" in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, thanks for the comment!