Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,6 @@ def _get_dtype(
2. Else, use the dtype provided as a dict or str
"""
is_sharded = sharded_metadata is not None
asked_dtype = dtype

if dtype is not None:
if isinstance(dtype, str):
Expand Down Expand Up @@ -806,24 +805,21 @@ def _get_dtype(
if isinstance(dtype, dict):
main_dtype = dtype.get("", torch.get_default_dtype())
main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype

logger.warning_once(
"Using different dtypes per module is deprecated and will be removed in future versions "
"Setting different dtypes per backbone model might cause device errors downstream, therefore "
f"setting the dtype={main_dtype} for all modules."
)

else:
main_dtype = dtype

# Set it on the config and subconfigs
config.dtype = main_dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
# The dtype was "auto" -> try to read the subconfig dtype value if any
if asked_dtype == "auto":
sub_dtype = getattr(sub_config, "dtype", main_dtype)
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
# The dtype was provided as a dict, try to see if we match the subconfig name
elif isinstance(dtype, dict):
sub_dtype = dtype.get(sub_config_key, main_dtype)
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
else:
sub_dtype = main_dtype
sub_config.dtype = sub_dtype
sub_config.dtype = main_dtype

return config, main_dtype

Expand Down
24 changes: 13 additions & 11 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ def test_model_from_config_dtype_composite(self):
"""
Test that from_pretrained works with dtype being as a dict per each sub-config in composite config
Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
Note: passing per-sub-config dtypes is a deprecated feature, and we fall back to the main dtype in
all cases below. This test checks that the dtype fallback works correctly.
"""
# Load without dtype specified
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
Expand All @@ -587,42 +589,42 @@ def test_model_from_config_dtype_composite(self):
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
self.assertIsInstance(model.config.dtype, torch.dtype)

# should be able to set dtype as a dict for each sub-config
# should be able to accept dtype as a dict for each sub-config
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
)
self.assertEqual(model.model.language_model.dtype, torch.float32)
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
self.assertEqual(model.model.language_model.dtype, torch.bfloat16)
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
self.assertIsInstance(model.config.dtype, torch.dtype)

# should be able to set the values as torch.dtype (not str)
# should be able to accept the values as torch.dtype (not str)
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
)
self.assertEqual(model.model.language_model.dtype, torch.float32)
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
self.assertEqual(model.model.language_model.dtype, torch.bfloat16)
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
self.assertIsInstance(model.config.dtype, torch.dtype)

# should be able to set the values in configs directly and pass it to `from_pretrained`
# should be able to accept the values in configs directly and pass it to `from_pretrained`
config = copy.deepcopy(model.config)
config.text_config.dtype = torch.float32
config.vision_config.dtype = torch.bfloat16
config.dtype = torch.float16
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto")
self.assertEqual(model.model.language_model.dtype, torch.float32)
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.model.language_model.dtype, torch.float16)
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
self.assertIsInstance(model.config.dtype, torch.dtype)

# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto")
self.assertEqual(
model.model.language_model.dtype, torch.float32
model.model.language_model.dtype, torch.float16
) # config requests float32 for text_config, but per-sub-config dtypes are deprecated, so the main dtype (float16) wins
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
self.assertIsInstance(model.config.dtype, torch.dtype)

Expand Down