diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4bd7e4f649e8..5fd8ff53b2a9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1454,6 +1454,12 @@ def _from_config(cls, config, **kwargs): if isinstance(dtype, str): dtype = getattr(torch, dtype) + # Set the same `dtype` on all subconfigs to avoid dtype mismatch. When "auto" dtype + # with nested models, we can't dispatch different dtype per backbone module + for sub_config_key in config.sub_configs: + if (sub_config := getattr(config, sub_config_key)) is not None: + sub_config.dtype = dtype + # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs) if "attn_implementation" in kwargs: config._attn_implementation = kwargs.pop("attn_implementation") diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 659563fa6d2f..3ce7f8032e1d 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -652,6 +652,15 @@ def test_model_from_config_dtype_composite(self): TINY_LLAVA, dtype={"text_config": "float32", "vision_config": "int64", "": "float16"} ) + # Check that `from_config` also works and uses the same dtype for all modules + config = AutoConfig.from_pretrained(TINY_LLAVA) + config.text_config.dtype = torch.float16 + config.dtype = torch.float32 + model = LlavaForConditionalGeneration._from_config(config) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.float32) + self.assertEqual(model.dtype, torch.float32) + def test_model_from_pretrained_dtype(self): # test that the model can be instantiated with dtype of either # 1. explicit from_pretrained's dtype argument