Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def _load_state_dict_into_meta_model(
dtype=None,
load_in_8bit=False,
is_safetensors=False,
keep_in_fp32_modules=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
Expand Down Expand Up @@ -611,7 +612,12 @@ def _load_state_dict_into_meta_model(
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
param = param.to(dtype)
if keep_in_fp32_modules is not None and any(
module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules
):
param = param.to(torch.float32)
else:
param = param.to(dtype)

# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
if dtype is None:
Expand Down Expand Up @@ -1881,6 +1887,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
load_in_8bit_skip_modules (`List[str]`, *optional*):
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
as Jukebox that has several heads in different places and not necessarily at the last position.
keep_in_fp32_modules (`List[str]`, *optional*):
An explicit list of the modules that we want to keep in full precision. This is somtimes needed to
retain the same performance as the full precision model when loading a model in half precision.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
Expand Down Expand Up @@ -1968,6 +1977,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
keep_in_fp32_modules = kwargs.pop("keep_in_fp32_modules", None)
Comment thread
younesbelkada marked this conversation as resolved.
Outdated
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

Expand All @@ -1982,6 +1992,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

if keep_in_fp32_modules is not None and not low_cpu_mem_usage:
# Force `low_cpu_mem_usage` to be set to `True` - check the PR:
logger.warning(
"The argument `keep_in_fp32_modules` is used, force-enabling `low_cpu_mem_usage` to load the model"
)
low_cpu_mem_usage = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be force-set here.

@younesbelkada younesbelkada Dec 8, 2022

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

proposed something in 115c0d0


if low_cpu_mem_usage:
# low_cpu_mem_usage requires PyTorch >= 1.9 to have the meta device.
require_version_core("torch>=1.9")
Expand Down Expand Up @@ -2309,6 +2326,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = load_in_8bit_skip_modules

if keep_in_fp32_modules is not None and isinstance(keep_in_fp32_modules, list):
modules_to_not_convert.extend(keep_in_fp32_modules)

model = replace_8bit_linear(
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
)
Expand Down Expand Up @@ -2415,6 +2436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
keep_in_fp32_modules=keep_in_fp32_modules,
)

model.is_loaded_in_8bit = load_in_8bit
Expand Down Expand Up @@ -2458,6 +2480,7 @@ def _load_pretrained_model(
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
keep_in_fp32_modules=None,
):
is_safetensors = False
if load_in_8bit:
Expand Down Expand Up @@ -2534,11 +2557,25 @@ def _fix_key(key):
if key.startswith(prefix):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]

# upcast in fp32 if any
target_dtype = dtype
if keep_in_fp32_modules is not None and any(
module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should also add a test of dtype being float16 here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in 8014c34

):
target_dtype = torch.float32

if param.device == torch.device("meta"):
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
set_module_8bit_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
)
elif keep_in_fp32_modules is not None and state_dict is not None:
for key in state_dict:
if any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules):
state_dict[key] = state_dict[key].to(torch.float32)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not useful as with torch.load_state_dict, the weights are converted to the dtype inside the model. So it's the model dtype that you should fix here.

Also this removes the necessity for an Accelerate warning above, no?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Should be addressed in cb89c42


# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
Expand Down Expand Up @@ -2681,6 +2718,7 @@ def _find_mismatched_keys(
dtype=dtype,
load_in_8bit=load_in_8bit,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
Expand Down
9 changes: 9 additions & 0 deletions tests/mixed_int8/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ def test_device_and_dtype_assignment(self):
# Check this does not throw an error
_ = self.model_fp16.float()

def test_fp32_int8_conversion(self):
r"""
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
"""
model = AutoModelForSeq2SeqLM.from_pretrained(
"t5-small", load_in_8bit=True, keep_in_fp32_modules=["wo"], device_map="auto"
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)


class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):
Expand Down
39 changes: 38 additions & 1 deletion tests/models/t5/test_modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
import unittest

from transformers import T5Config, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
torch_device,
)
from transformers.utils import cached_property

from ...generation.test_utils import GenerationTesterMixin
Expand Down Expand Up @@ -820,6 +827,36 @@ def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])


@require_torch
@require_accelerate
@require_tokenizers
class T5ModelFp16Tests(unittest.TestCase):
def test_fp16_fp32_conversion(self):
r"""
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
"""
# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.float16, keep_in_fp32_modules=["wo"]
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)

# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.float16, keep_in_fp32_modules=["wo"], low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)

# Load using `accelerate`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.float16, keep_in_fp32_modules=["wo"], device_map="auto"
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)


@require_torch
@require_sentencepiece
@require_tokenizers
Expand Down