-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[core] Layerwise Upcasting #10347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[core] Layerwise Upcasting #10347
Changes from 27 commits
Commits
Show all changes
55 commits
Select commit
Hold shift + click to select a range
36b0c37
update
a-r-r-o-w 42046c0
update
a-r-r-o-w 7dc739b
make style
a-r-r-o-w 7ed7141
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 1fa4ee5
remove dynamo disable
a-r-r-o-w da4907e
add coauthor
a-r-r-o-w bc2ada4
update
a-r-r-o-w 91bfc3d
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 7c31bb0
update
a-r-r-o-w 8975bbf
update
a-r-r-o-w 341fbfc
update mixin
a-r-r-o-w 5f898a1
add some basic tests
a-r-r-o-w 558c64e
update
a-r-r-o-w 7858f2c
update
a-r-r-o-w 2663026
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 3d84b9e
non_blocking
a-r-r-o-w 9372647
improvements
a-r-r-o-w a0f1de7
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w e586ef3
update
a-r-r-o-w cfe6318
norm.* -> norm
a-r-r-o-w 9235f77
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 7627415
apply suggestions from review
a-r-r-o-w b9e1217
add example
a-r-r-o-w bde103c
update hook implementation to the latest changes from pyramid attenti…
a-r-r-o-w 64e6c9c
deinitialize should raise an error
a-r-r-o-w 7037133
update doc page
a-r-r-o-w f1b46d6
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 390742b
Apply suggestions from code review
a-r-r-o-w 19901e7
update docs
a-r-r-o-w 3ae32b4
update
a-r-r-o-w bf797e7
refactor
a-r-r-o-w d22465a
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 5956a9e
fix _always_upcast_modules for asym ae and vq_model
a-r-r-o-w 93bd8ee
fix lumina embedding forward to not depend on weight dtype
a-r-r-o-w 77a32a7
refactor tests
a-r-r-o-w 1335d7e
add simple lora inference tests
a-r-r-o-w a263e1a
_always_upcast_modules -> _precision_sensitive_module_patterns
a-r-r-o-w 93e36ba
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 245137f
remove todo comments about review; revert changes to self.dtype in un…
a-r-r-o-w b713511
check layer dtypes in lora test
a-r-r-o-w 4450b1c
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w ed14d26
fix UNet1DModelTests::test_layerwise_upcasting_inference
a-r-r-o-w 2c9c33f
_precision_sensitive_module_patterns -> _skip_layerwise_casting_patte…
a-r-r-o-w 08211f7
skip test in NCSNppModelTests
a-r-r-o-w 59e04c3
skip tests for AutoencoderTinyTests
a-r-r-o-w 0a16826
skip tests for AutoencoderOobleckTests
a-r-r-o-w 1d306b8
skip tests for UNet1DModelTests - unsupported pytorch operations
a-r-r-o-w a9364bd
layerwise_upcasting -> layerwise_casting
a-r-r-o-w c4d5a2b
skip tests for UNetRLModelTests; needs next pytorch release for curre…
a-r-r-o-w d175d93
add layerwise fp8 pipeline test
a-r-r-o-w bf11691
use xfail
a-r-r-o-w 1c523b2
Apply suggestions from code review
a-r-r-o-w 7803364
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w 376adf9
add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32…
a-r-r-o-w 719e8d3
add note about memory consumption on tesla CI runner for failing test
a-r-r-o-w File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from ..utils import is_torch_available | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| # Copyright 2024 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import functools | ||
| from typing import Any, Dict, Optional, Tuple | ||
|
|
||
| import torch | ||
|
|
||
| from ..utils.logging import get_logger | ||
|
|
||
|
|
||
| logger = get_logger(__name__) # pylint: disable=invalid-name | ||
|
|
||
|
|
||
| class ModelHook: | ||
| r""" | ||
| A hook that contains callbacks to be executed just before and after the forward method of a model. | ||
| """ | ||
|
|
||
| _is_stateful = False | ||
|
|
||
| def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | ||
| r""" | ||
| Hook that is executed when a model is initialized. | ||
|
|
||
| Args: | ||
| module (`torch.nn.Module`): | ||
| The module attached to this hook. | ||
| """ | ||
| return module | ||
|
|
||
| def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | ||
| r""" | ||
| Hook that is executed when a model is deinitalized. | ||
|
|
||
| Args: | ||
| module (`torch.nn.Module`): | ||
| The module attached to this hook. | ||
| """ | ||
| module.forward = module._old_forward | ||
| del module._old_forward | ||
| return module | ||
|
|
||
| def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: | ||
| r""" | ||
| Hook that is executed just before the forward method of the model. | ||
|
|
||
| Args: | ||
| module (`torch.nn.Module`): | ||
| The module whose forward pass will be executed just after this event. | ||
| args (`Tuple[Any]`): | ||
| The positional arguments passed to the module. | ||
| kwargs (`Dict[Str, Any]`): | ||
| The keyword arguments passed to the module. | ||
| Returns: | ||
| `Tuple[Tuple[Any], Dict[Str, Any]]`: | ||
| A tuple with the treated `args` and `kwargs`. | ||
| """ | ||
| return args, kwargs | ||
|
|
||
| def post_forward(self, module: torch.nn.Module, output: Any) -> Any: | ||
| r""" | ||
| Hook that is executed just after the forward method of the model. | ||
|
|
||
| Args: | ||
| module (`torch.nn.Module`): | ||
| The module whose forward pass been executed just before this event. | ||
| output (`Any`): | ||
| The output of the module. | ||
| Returns: | ||
| `Any`: The processed `output`. | ||
| """ | ||
| return output | ||
|
|
||
| def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: | ||
| r""" | ||
| Hook that is executed when the hook is detached from a module. | ||
|
|
||
| Args: | ||
| module (`torch.nn.Module`): | ||
| The module detached from this hook. | ||
| """ | ||
| return module | ||
|
|
||
| def reset_state(self, module: torch.nn.Module): | ||
| if self._is_stateful: | ||
| raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") | ||
| return module | ||
|
|
||
|
|
||
| class HookRegistry: | ||
| def __init__(self, module_ref: torch.nn.Module) -> None: | ||
| super().__init__() | ||
|
|
||
| self.hooks: Dict[str, ModelHook] = {} | ||
|
|
||
| self._module_ref = module_ref | ||
| self._hook_order = [] | ||
|
|
||
| def register_hook(self, hook: ModelHook, name: str) -> None: | ||
| if name in self.hooks.keys(): | ||
| logger.warning(f"Hook with name {name} already exists, replacing it.") | ||
|
|
||
| if hasattr(self._module_ref, "_old_forward"): | ||
| old_forward = self._module_ref._old_forward | ||
| else: | ||
| old_forward = self._module_ref.forward | ||
| self._module_ref._old_forward = self._module_ref.forward | ||
|
|
||
| self._module_ref = hook.initialize_hook(self._module_ref) | ||
|
|
||
| if hasattr(hook, "new_forward"): | ||
| rewritten_forward = hook.new_forward | ||
|
|
||
| def new_forward(module, *args, **kwargs): | ||
| args, kwargs = hook.pre_forward(module, *args, **kwargs) | ||
| output = rewritten_forward(module, *args, **kwargs) | ||
| return hook.post_forward(module, output) | ||
| else: | ||
|
|
||
| def new_forward(module, *args, **kwargs): | ||
| args, kwargs = hook.pre_forward(module, *args, **kwargs) | ||
| output = old_forward(*args, **kwargs) | ||
| return hook.post_forward(module, output) | ||
|
|
||
| self._module_ref.forward = functools.update_wrapper( | ||
| functools.partial(new_forward, self._module_ref), old_forward | ||
| ) | ||
|
|
||
| self.hooks[name] = hook | ||
| self._hook_order.append(name) | ||
|
|
||
| def get_hook(self, name: str) -> Optional[ModelHook]: | ||
| if name not in self.hooks.keys(): | ||
| return None | ||
| return self.hooks[name] | ||
|
|
||
| def remove_hook(self, name: str, recurse: bool = True) -> None: | ||
| if name in self.hooks.keys(): | ||
| hook = self.hooks[name] | ||
| self._module_ref = hook.deinitalize_hook(self._module_ref) | ||
| del self.hooks[name] | ||
| self._hook_order.remove(name) | ||
|
|
||
| if recurse: | ||
| for module_name, module in self._module_ref.named_modules(): | ||
| if module_name == "": | ||
| continue | ||
| if hasattr(module, "_diffusers_hook"): | ||
| module._diffusers_hook.remove_hook(name, recurse=False) | ||
|
|
||
| def reset_stateful_hooks(self, recurse: bool = True) -> None: | ||
| for hook_name in self._hook_order: | ||
| hook = self.hooks[hook_name] | ||
| if hook._is_stateful: | ||
| hook.reset_state(self._module_ref) | ||
|
|
||
| if recurse: | ||
| for module_name, module in self._module_ref.named_modules(): | ||
| if module_name == "": | ||
| continue | ||
| if hasattr(module, "_diffusers_hook"): | ||
| module._diffusers_hook.reset_stateful_hooks(recurse=False) | ||
|
|
||
| @classmethod | ||
| def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": | ||
| if not hasattr(module, "_diffusers_hook"): | ||
| module._diffusers_hook = cls(module) | ||
| return module._diffusers_hook | ||
|
|
||
| def __repr__(self) -> str: | ||
| hook_repr = "" | ||
| for i, hook_name in enumerate(self._hook_order): | ||
| hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" | ||
| if i < len(self._hook_order) - 1: | ||
| hook_repr += "\n" | ||
| return f"HookRegistry(\n{hook_repr}\n)" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.