-
Notifications
You must be signed in to change notification settings - Fork 33.6k
Add keep_in_fp32_modules support
#20683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
15fd00a
2bd00b3
115c0d0
a62a594
9b0688f
e3498da
c688e34
8014c34
966cc06
0f75387
243e6b5
1e80f14
73743b6
cb89c42
7d47df2
1d21843
50524ad
986730b
ef56114
703c7f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -561,6 +561,7 @@ def _load_state_dict_into_meta_model( | |
| dtype=None, | ||
| load_in_8bit=False, | ||
| is_safetensors=False, | ||
| keep_in_fp32_modules=None, | ||
| ): | ||
| """ | ||
| This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its | ||
|
|
@@ -611,7 +612,12 @@ def _load_state_dict_into_meta_model( | |
| # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params | ||
| # in int/uint/bool and not cast them. | ||
| if dtype is not None and torch.is_floating_point(param): | ||
| param = param.to(dtype) | ||
| if keep_in_fp32_modules is not None and any( | ||
| module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules | ||
| ): | ||
| param = param.to(torch.float32) | ||
| else: | ||
| param = param.to(dtype) | ||
|
|
||
| # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model | ||
| if dtype is None: | ||
|
|
@@ -1881,6 +1887,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| load_in_8bit_skip_modules (`List[str]`, *optional*): | ||
| An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such | ||
| as Jukebox that has several heads in different places and not necessarily at the last position. | ||
| keep_in_fp32_modules (`List[str]`, *optional*): | ||
| An explicit list of the modules that we want to keep in full precision. This is sometimes needed to | ||
| retain the same performance as the full precision model when loading a model in half precision. | ||
| subfolder (`str`, *optional*, defaults to `""`): | ||
| In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can | ||
| specify the folder name here. | ||
|
|
@@ -1968,6 +1977,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| load_in_8bit = kwargs.pop("load_in_8bit", False) | ||
| load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0) | ||
| load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None) | ||
| keep_in_fp32_modules = kwargs.pop("keep_in_fp32_modules", None) | ||
| subfolder = kwargs.pop("subfolder", "") | ||
| commit_hash = kwargs.pop("_commit_hash", None) | ||
|
|
||
|
|
@@ -1982,6 +1992,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| elif not low_cpu_mem_usage: | ||
| raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") | ||
|
|
||
| if keep_in_fp32_modules is not None and not low_cpu_mem_usage: | ||
| # Force `low_cpu_mem_usage` to be set to `True` - check the PR: | ||
| logger.warning( | ||
| "The argument `keep_in_fp32_modules` is used, force-enabling `low_cpu_mem_usage` to load the model" | ||
| ) | ||
| low_cpu_mem_usage = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't be force-set here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. proposed something in 115c0d0 |
||
|
|
||
| if low_cpu_mem_usage: | ||
| # low_cpu_mem_usage requires PyTorch >= 1.9 to have the meta device. | ||
| require_version_core("torch>=1.9") | ||
|
|
@@ -2309,6 +2326,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| modules_to_not_convert = get_keys_to_not_convert(model) | ||
| else: | ||
| modules_to_not_convert = load_in_8bit_skip_modules | ||
|
|
||
| if keep_in_fp32_modules is not None and isinstance(keep_in_fp32_modules, list): | ||
| modules_to_not_convert.extend(keep_in_fp32_modules) | ||
|
|
||
| model = replace_8bit_linear( | ||
| model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert | ||
| ) | ||
|
|
@@ -2415,6 +2436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| offload_state_dict=offload_state_dict, | ||
| dtype=torch_dtype, | ||
| load_in_8bit=load_in_8bit, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| ) | ||
|
|
||
| model.is_loaded_in_8bit = load_in_8bit | ||
|
|
@@ -2458,6 +2480,7 @@ def _load_pretrained_model( | |
| offload_state_dict=None, | ||
| dtype=None, | ||
| load_in_8bit=False, | ||
| keep_in_fp32_modules=None, | ||
| ): | ||
| is_safetensors = False | ||
| if load_in_8bit: | ||
|
|
@@ -2534,11 +2557,25 @@ def _fix_key(key): | |
| if key.startswith(prefix): | ||
| key = ".".join(key.split(".")[1:]) | ||
| param = model_state_dict[key] | ||
|
|
||
| # upcast in fp32 if any | ||
| target_dtype = dtype | ||
| if keep_in_fp32_modules is not None and any( | ||
| module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should also add a test of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. added in 8014c34 |
||
| ): | ||
| target_dtype = torch.float32 | ||
|
|
||
| if param.device == torch.device("meta"): | ||
| if not load_in_8bit: | ||
| set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) | ||
| set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)) | ||
| else: | ||
| set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) | ||
| set_module_8bit_tensor_to_device( | ||
| model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype) | ||
| ) | ||
| elif keep_in_fp32_modules is not None and state_dict is not None: | ||
| for key in state_dict: | ||
| if any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules): | ||
| state_dict[key] = state_dict[key].to(torch.float32) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not useful as with Also this removes the necessity for an Accelerate warning above, no?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes! Should be addressed in cb89c42 |
||
|
|
||
| # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. | ||
| if _fast_init: | ||
|
|
@@ -2681,6 +2718,7 @@ def _find_mismatched_keys( | |
| dtype=dtype, | ||
| load_in_8bit=load_in_8bit, | ||
| is_safetensors=is_safetensors, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| ) | ||
| error_msgs += new_error_msgs | ||
| else: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.