Merged
21 changes: 17 additions & 4 deletions src/transformers/configuration_utils.py
@@ -190,6 +190,8 @@ class PretrainedConfig(PushToHubMixin):
     """
 
     model_type: str = ""
+    base_config_key: str = ""
+    sub_configs: Dict[str, "PretrainedConfig"] = {}
     is_composition: bool = False
     attribute_map: Dict[str, str] = {}
     _auto_class: Optional[str] = None
@@ -543,11 +545,22 @@ def from_pretrained(
         cls._set_token_in_kwargs(kwargs, token)
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if cls.base_config_key and cls.base_config_key in config_dict:
+            config_dict = config_dict[cls.base_config_key]
+
         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
+            # sometimes the config has no `base_config_key` if the config is used in several composite models,
+            # e.g. LlamaConfig. In that case we try to see if there is a match in `model_type` before raising a warning
+            for k, v in config_dict.items():
+                if isinstance(v, dict) and v.get("model_type") == cls.model_type:
+                    config_dict = v
+
+            # raise a warning only if we still can't find a match in `model_type`
+            if config_dict["model_type"] != cls.model_type:
+                logger.warning(
+                    f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+                )
 
         return cls.from_dict(config_dict, **kwargs)

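Taken together, the two hunks above replace the per-model `from_pretrained` overrides: any sub-config that declares `base_config_key` can be loaded straight from a composite checkpoint by the shared implementation. A minimal sketch of the resulting behavior, assuming the ALIGN classes updated later in this diff (the checkpoint name is only an illustrative example):

from transformers import AlignTextConfig

# AlignTextConfig declares `base_config_key = "text_config"` (see the ALIGN changes below),
# so the shared `PretrainedConfig.from_pretrained` picks the nested "text_config" dict out of
# the composite ALIGN config on its own; the model-specific override is no longer needed.
text_config = AlignTextConfig.from_pretrained("kakaobrain/align-base")
print(text_config.model_type)  # "align_text_model"
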
17 changes: 8 additions & 9 deletions src/transformers/modeling_utils.py
@@ -1597,15 +1597,14 @@ def _autoset_attn_implementation(
         # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
         # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
         # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
-        for key in config:
-            if isinstance(getattr(config, key), PretrainedConfig):
-                sub_config = getattr(config, key)
-                curr_attn_implementation = (
-                    requested_attn_implementation
-                    if not isinstance(requested_attn_implementation, dict)
-                    else requested_attn_implementation.get(key, None)
-                )
-                sub_config._attn_implementation_internal = curr_attn_implementation
+        for key in config.sub_configs.keys():
+            sub_config = getattr(config, key)
+            curr_attn_implementation = (
+                requested_attn_implementation
+                if not isinstance(requested_attn_implementation, dict)
+                else requested_attn_implementation.get(key, None)
+            )
+            sub_config._attn_implementation_internal = curr_attn_implementation
Collaborator comment on lines +1600 to +1607: "very nice!"
         if use_flash_attention_2:
             logger.warning_once(
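Because the dispatch now iterates `config.sub_configs`, a composite model can be given one attention implementation per sub-model by passing a dict keyed by sub-config name. A hedged sketch of that usage (the checkpoint is illustrative, and each sub-model must actually support the implementation requested for it, otherwise an error is raised as the comment above notes):

from transformers import LlavaForConditionalGeneration

# A plain string applies to every sub-model; a dict assigns an implementation per
# sub-config name. Keys left out of the dict fall back to the default (None).
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative composite checkpoint
    attn_implementation={"text_config": "sdpa", "vision_config": "eager"},
)
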
42 changes: 4 additions & 38 deletions src/transformers/models/align/configuration_align.py
@@ -14,8 +14,7 @@
# limitations under the License.
"""ALIGN model configuration"""

import os
from typing import TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, List


if TYPE_CHECKING:
@@ -95,6 +94,7 @@ class AlignTextConfig(PretrainedConfig):
```"""

model_type = "align_text_model"
base_config_key = "text_config"

def __init__(
self,
@@ -133,24 +133,6 @@ def __init__(
self.use_cache = use_cache
self.pad_token_id = pad_token_id

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the text config dict if we are loading from AlignConfig
if config_dict.get("model_type") == "align":
config_dict = config_dict["text_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


class AlignVisionConfig(PretrainedConfig):
r"""
@@ -223,6 +205,7 @@ class AlignVisionConfig(PretrainedConfig):
```"""

model_type = "align_vision_model"
base_config_key = "vision_config"

def __init__(
self,
@@ -272,24 +255,6 @@ def __init__(
self.drop_connect_rate = drop_connect_rate
self.num_hidden_layers = sum(num_block_repeats) * 4

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the vision config dict if we are loading from AlignConfig
if config_dict.get("model_type") == "align":
config_dict = config_dict["vision_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


class AlignConfig(PretrainedConfig):
r"""
@@ -340,6 +305,7 @@ class AlignConfig(PretrainedConfig):
```"""

model_type = "align"
sub_configs = {"text_config": AlignTextConfig, "vision_config": AlignVisionConfig}

def __init__(
self,
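The ALIGN edits above are the pattern this PR repeats for every composite model: each sub-config states where it is nested via `base_config_key`, and the composite config lists its sub-config classes via `sub_configs`, which is what the generic `from_pretrained` and `_autoset_attn_implementation` shown earlier rely on. Roughly, for a hypothetical model (all names below are made up for illustration):

from transformers import PretrainedConfig

class MyTextConfig(PretrainedConfig):
    model_type = "my_text_model"
    base_config_key = "text_config"  # key this config is stored under in the composite config

class MyVisionConfig(PretrainedConfig):
    model_type = "my_vision_model"
    base_config_key = "vision_config"

class MyCompositeConfig(PretrainedConfig):
    model_type = "my_model"
    # lets generic code discover and iterate the nested configs
    sub_configs = {"text_config": MyTextConfig, "vision_config": MyVisionConfig}
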
23 changes: 2 additions & 21 deletions src/transformers/models/altclip/configuration_altclip.py
@@ -14,9 +14,6 @@
# limitations under the License.
"""AltCLIP model configuration"""

import os
from typing import Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging

@@ -199,6 +196,7 @@ class AltCLIPVisionConfig(PretrainedConfig):
```"""

model_type = "altclip_vision_model"
base_config_key = "vision_config"

def __init__(
self,
@@ -233,24 +231,6 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the vision config dict if we are loading from AltCLIPConfig
if config_dict.get("model_type") == "altclip":
config_dict = config_dict["vision_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


class AltCLIPConfig(PretrainedConfig):
r"""
@@ -298,6 +278,7 @@ class AltCLIPConfig(PretrainedConfig):
```"""

model_type = "altclip"
sub_configs = {"text_config": AltCLIPTextConfig, "vision_config": AltCLIPVisionConfig}

def __init__(
self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs
47 changes: 11 additions & 36 deletions src/transformers/models/bark/configuration_bark.py
@@ -14,12 +14,11 @@
# limitations under the License.
"""BARK model configuration"""

import os
from typing import Dict, Optional, Union
from typing import Dict

from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings, logging
from ..auto import CONFIG_MAPPING
from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -64,7 +63,6 @@


class BarkSubModelConfig(PretrainedConfig):
model_type = "bark_module"
keys_to_ignore_at_inference = ["past_key_values"]

attribute_map = {
@@ -101,38 +99,6 @@ def __init__(

super().__init__(**kwargs)

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
) -> "PretrainedConfig":
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision

cls._set_token_in_kwargs(kwargs, token)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the config dict if we are loading from Bark
if config_dict.get("model_type") == "bark":
config_dict = config_dict[f"{cls.model_type}_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


@add_start_docstrings(
BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkSemanticConfig", model="BarkSemanticModel"),
@@ -154,6 +120,7 @@ def from_pretrained(
)
class BarkSemanticConfig(BarkSubModelConfig):
model_type = "semantic"
base_config_key = "semantic_config"


@add_start_docstrings(
@@ -176,6 +143,7 @@ class BarkSemanticConfig(BarkSubModelConfig):
)
class BarkCoarseConfig(BarkSubModelConfig):
model_type = "coarse_acoustics"
base_config_key = "coarse_acoustics_config"


@add_start_docstrings(
Expand Down Expand Up @@ -203,6 +171,7 @@ class BarkCoarseConfig(BarkSubModelConfig):
)
class BarkFineConfig(BarkSubModelConfig):
model_type = "fine_acoustics"
base_config_key = "fine_acoustics_config"

def __init__(self, tie_word_embeddings=True, n_codes_total=8, n_codes_given=1, **kwargs):
self.n_codes_total = n_codes_total
@@ -265,6 +234,12 @@ class BarkConfig(PretrainedConfig):
"""

model_type = "bark"
sub_configs = {
"semantic_config": BarkSemanticConfig,
"coarse_acoustics_config": BarkCoarseConfig,
"fine_acoustics_config": BarkFineConfig,
"codec_config": AutoConfig,
}

def __init__(
self,
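Bark follows the same scheme with three acoustic sub-configs plus a codec config resolved through `AutoConfig`. With `base_config_key` set on each subclass, loading a single sub-config from a full Bark checkpoint now goes through the shared code path; a small sketch (the checkpoint name is illustrative):

from transformers import BarkCoarseConfig

# `base_config_key = "coarse_acoustics_config"` tells the shared `from_pretrained`
# which nested dict to pull out of a full Bark config, replacing the removed
# Bark-specific override.
coarse_config = BarkCoarseConfig.from_pretrained("suno/bark-small")
print(coarse_config.model_type)  # "coarse_acoustics"
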
42 changes: 3 additions & 39 deletions src/transformers/models/blip/configuration_blip.py
@@ -14,9 +14,6 @@
# limitations under the License.
"""Blip model configuration"""

import os
from typing import Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging

@@ -96,6 +93,7 @@ class BlipTextConfig(PretrainedConfig):
```"""

model_type = "blip_text_model"
base_config_key = "text_config"

def __init__(
self,
@@ -146,24 +144,6 @@ def __init__(
self.use_cache = use_cache
self.label_smoothing = label_smoothing

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the text config dict if we are loading from BlipConfig
if config_dict.get("model_type") == "blip":
config_dict = config_dict["text_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


class BlipVisionConfig(PretrainedConfig):
r"""
@@ -215,6 +195,7 @@ class BlipVisionConfig(PretrainedConfig):
```"""

model_type = "blip_vision_model"
base_config_key = "vision_config"

def __init__(
self,
@@ -245,24 +226,6 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the vision config dict if we are loading from BlipConfig
if config_dict.get("model_type") == "blip":
config_dict = config_dict["vision_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


class BlipConfig(PretrainedConfig):
r"""
@@ -316,6 +279,7 @@ class BlipConfig(PretrainedConfig):
```"""

model_type = "blip"
sub_configs = {"text_config": BlipTextConfig, "vision_config": BlipVisionConfig}

def __init__(
self,
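The `sub_configs` mapping also makes the nested configs discoverable without the attribute scanning that the old `_autoset_attn_implementation` loop relied on. A short sketch using the BLIP config above:

from transformers import BlipConfig

config = BlipConfig()  # default text and vision sub-configs
for key, sub_cls in BlipConfig.sub_configs.items():
    # each key names an attribute holding an instance of the declared sub-config class
    print(key, sub_cls.__name__, type(getattr(config, key)).__name__)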