diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7fb1c62dae85..5ddea3e34e19 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -25,7 +25,7 @@
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from functools import partial
+from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -1912,6 +1912,7 @@ def get_memory_footprint(self, return_buffers=True):
             mem = mem + mem_bufs
         return mem
 
+    @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         # Checks if the model has been loaded in 8-bit
         if getattr(self, "is_quantized", False):
@@ -1922,6 +1923,7 @@ def cuda(self, *args, **kwargs):
         else:
             return super().cuda(*args, **kwargs)
 
+    @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
         # Checks if the model has been loaded in 8-bit
        if getattr(self, "is_quantized", False):
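
For context on what `@wraps` buys here (not part of the patch): `functools.wraps` copies the wrapped function's metadata (`__doc__`, `__name__`, `__qualname__`, plus a `__wrapped__` reference) onto the override. The `cuda`/`to` overrides above carry only a `#` comment and no docstring, so without the decorator `model.to.__doc__` would be `None`; with it, `help(model.to)` and signature tooling keep showing `torch.nn.Module.to`'s documentation. A minimal sketch, assuming only that `torch` is installed; `DummyQuantizedModel` is a hypothetical stand-in, not a class from the library:

    from functools import wraps

    import torch


    class DummyQuantizedModel(torch.nn.Module):
        # Hypothetical stand-in for a quantized model; real code sets this flag at load time.
        is_quantized = True

        @wraps(torch.nn.Module.to)
        def to(self, *args, **kwargs):
            # Guard analogous to the patch's: moving a quantized model is refused.
            if getattr(self, "is_quantized", False):
                raise ValueError("`.to` is not supported for quantized models.")
            return super().to(*args, **kwargs)


    # @wraps keeps the original metadata on the override:
    print(DummyQuantizedModel.to.__name__)                               # -> to
    print(DummyQuantizedModel.to.__doc__ is torch.nn.Module.to.__doc__)  # -> True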