diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index 7839bf7813f2..cfc3c1c104d3 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -55,6 +55,38 @@ class TimmWrapperModelOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor, ...]] = None
 
 
+def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_kwargs):
+    """
+    Create the timm model named by `config.architecture`, turning timm's "Unknown model" failure into an
+    actionable `ImportError` that suggests upgrading timm.
+
+    Args:
+        config (`TimmWrapperConfig`):
+            Configuration whose `architecture` attribute is passed to `timm.create_model` as the model name.
+        **model_kwargs:
+            Extra keyword arguments forwarded to `timm.create_model` (callers pass e.g. `num_classes`).
+            `pretrained=False` is always supplied here, so `pretrained` must not appear in `model_kwargs`.
+
+    Returns:
+        The object returned by `timm.create_model`.
+
+    Raises:
+        ImportError: If `timm.create_model` raises a `RuntimeError` whose message contains "Unknown model".
+        RuntimeError: Any other `RuntimeError` from `timm.create_model`, re-raised as is.
+    """
+    try:
+        return timm.create_model(config.architecture, pretrained=False, **model_kwargs)
+    except RuntimeError as e:
+        if "Unknown model" not in str(e):
+            # Not the unknown-architecture case: a bare `raise` re-raises the active exception unchanged,
+            # whereas `raise e` would add a redundant re-raise entry to its traceback.
+            raise
+        raise ImportError(
+            f"The model architecture '{config.architecture}' is not supported in your version of timm ({timm.__version__}). "
+            "Please upgrade timm to a more recent version with `pip install -U timm`."
+        ) from e
+
+
 @auto_docstring
 class TimmWrapperPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
@@ -138,7 +170,7 @@ def __init__(self, config: TimmWrapperConfig):
         super().__init__(config)
         # using num_classes=0 to avoid creating classification head
         extra_init_kwargs = config.model_args or {}
-        self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **extra_init_kwargs)
+        self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
         self.post_init()
 
     @auto_docstring
@@ -254,8 +286,8 @@ def __init__(self, config: TimmWrapperConfig):
             )
 
         extra_init_kwargs = config.model_args or {}
-        self.timm_model = timm.create_model(
-            config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs
+        self.timm_model = _create_timm_model_with_error_handling(
+            config, num_classes=config.num_labels, **extra_init_kwargs
         )
         self.num_labels = config.num_labels
         self.post_init()