diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py
index 2d759ac34256..26c0eb95a4bc 100644
--- a/src/transformers/benchmark/benchmark_args.py
+++ b/src/transformers/benchmark/benchmark_args.py
@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Tuple
 
-from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
+from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
 from .benchmark_args_utils import BenchmarkArguments
 
 
@@ -76,8 +76,8 @@ def __init__(self, **kwargs):
         )
 
     @cached_property
-    @torch_required
     def _setup_devices(self) -> Tuple["torch.device", int]:
+        requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
         if not self.cuda:
             device = torch.device("cpu")
@@ -95,19 +95,19 @@ def is_tpu(self):
         return is_torch_tpu_available() and self.tpu
 
     @property
-    @torch_required
     def device_idx(self) -> int:
+        requires_backends(self, ["torch"])
         # TODO(PVP): currently only single GPU is supported
         return torch.cuda.current_device()
 
     @property
-    @torch_required
     def device(self) -> "torch.device":
+        requires_backends(self, ["torch"])
         return self._setup_devices[0]
 
     @property
-    @torch_required
     def n_gpu(self):
+        requires_backends(self, ["torch"])
         return self._setup_devices[1]
 
     @property
diff --git a/src/transformers/benchmark/benchmark_args_tf.py b/src/transformers/benchmark/benchmark_args_tf.py
index 8f3a9cea9465..12cb6f5cbbeb 100644
--- a/src/transformers/benchmark/benchmark_args_tf.py
+++ b/src/transformers/benchmark/benchmark_args_tf.py
@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Tuple
 
-from ..utils import cached_property, is_tf_available, logging, tf_required
+from ..utils import cached_property, is_tf_available, logging, requires_backends
 from .benchmark_args_utils import BenchmarkArguments
 
 
@@ -77,8 +77,8 @@ def __init__(self, **kwargs):
         )
 
     @cached_property
-    @tf_required
     def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
+        requires_backends(self, ["tf"])
         tpu = None
         if self.tpu:
             try:
@@ -91,8 +91,8 @@ def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver
         return tpu
 
     @cached_property
-    @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
+        requires_backends(self, ["tf"])
         if self.is_tpu:
             tf.config.experimental_connect_to_cluster(self._setup_tpu)
             tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
@@ -111,23 +111,23 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.clus
         return strategy
 
     @property
-    @tf_required
     def is_tpu(self) -> bool:
+        requires_backends(self, ["tf"])
         return self._setup_tpu is not None
 
     @property
-    @tf_required
     def strategy(self) -> "tf.distribute.Strategy":
+        requires_backends(self, ["tf"])
         return self._setup_strategy
 
     @property
-    @tf_required
     def gpu_list(self):
+        requires_backends(self, ["tf"])
         return tf.config.list_physical_devices("GPU")
 
     @property
-    @tf_required
     def n_gpu(self) -> int:
+        requires_backends(self, ["tf"])
         if self.cuda:
             return len(self.gpu_list)
         return 0
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index 30d97649a9fd..ff8fa009935f 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -42,7 +42,7 @@
     is_torch_device,
     is_torch_dtype,
     logging,
-    torch_required,
+    requires_backends,
 )
 
 
@@ -175,7 +175,6 @@ def as_tensor(value):
 
         return self
 
-    @torch_required
     def to(self, *args, **kwargs) -> "BatchFeature":
         """
         Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
@@ -190,6 +189,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+        requires_backends(self, ["torch"])
         import torch  # noqa
 
         new_data = {}
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index 23219a328b6b..f5d404f657bc 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -127,10 +127,8 @@
     is_vision_available,
     replace_return_docstrings,
     requires_backends,
-    tf_required,
     to_numpy,
     to_py_obj,
     torch_only_method,
-    torch_required,
     torch_version,
 )
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 011edfa1e77f..7b109ff02c42 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -56,8 +56,8 @@
     is_torch_device,
     is_torch_tensor,
     logging,
+    requires_backends,
     to_py_obj,
-    torch_required,
 )
 
 
@@ -739,7 +739,6 @@ def convert_to_tensors(
 
         return self
 
-    @torch_required
     def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         """
         Send all values to device by calling `v.to(device)` (PyTorch only).
@@ -750,6 +749,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         Returns:
             [`BatchEncoding`]: The same instance after modification.
         """
+        requires_backends(self, ["torch"])
 
         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index b92afac17125..e64bc977177e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -50,7 +50,6 @@
     is_torch_tpu_available,
     logging,
     requires_backends,
-    torch_required,
 )
 
 
@@ -1386,8 +1385,8 @@ def ddp_timeout_delta(self) -> timedelta:
         return timedelta(seconds=self.ddp_timeout)
 
     @cached_property
-    @torch_required
     def _setup_devices(self) -> "torch.device":
+        requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
         if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
             logger.warning(
@@ -1537,15 +1536,14 @@ def _setup_devices(self) -> "torch.device":
         return device
 
     @property
-    @torch_required
     def device(self) -> "torch.device":
         """
         The device used by this process.
         """
+        requires_backends(self, ["torch"])
         return self._setup_devices
 
     @property
-    @torch_required
     def n_gpu(self):
         """
         The number of GPUs used by this process.
@@ -1554,12 +1552,12 @@ def n_gpu(self):
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
+        requires_backends(self, ["torch"])
         # Make sure `self._n_gpu` is properly setup.
         _ = self._setup_devices
         return self._n_gpu
 
     @property
-    @torch_required
     def parallel_mode(self):
         """
         The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
@@ -1570,6 +1568,7 @@ def parallel_mode(self):
           `torch.nn.DistributedDataParallel`).
         - `ParallelMode.TPU`: several TPU cores.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return ParallelMode.TPU
         elif is_sagemaker_mp_enabled():
@@ -1584,11 +1583,12 @@ def parallel_mode(self):
         return ParallelMode.NOT_PARALLEL
 
     @property
-    @torch_required
     def world_size(self):
         """
         The number of processes used in parallel.
         """
+        requires_backends(self, ["torch"])
+
         if is_torch_tpu_available():
             return xm.xrt_world_size()
         elif is_sagemaker_mp_enabled():
@@ -1600,11 +1600,11 @@ def world_size(self):
         return 1
 
     @property
-    @torch_required
     def process_index(self):
         """
         The index of the current process used.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.get_ordinal()
         elif is_sagemaker_mp_enabled():
@@ -1616,11 +1616,11 @@ def process_index(self):
         return 0
 
     @property
-    @torch_required
     def local_process_index(self):
         """
         The index of the local process used.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.get_local_ordinal()
         elif is_sagemaker_mp_enabled():
diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py
index b3068b211a6d..3cacfba16e8f 100644
--- a/src/transformers/training_args_tf.py
+++ b/src/transformers/training_args_tf.py
@@ -17,7 +17,7 @@
 from typing import Optional, Tuple
 
 from .training_args import TrainingArguments
-from .utils import cached_property, is_tf_available, logging, tf_required
+from .utils import cached_property, is_tf_available, logging, requires_backends
 
 
 logger = logging.get_logger(__name__)
@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
     xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
 
     @cached_property
-    @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
+        requires_backends(self, ["tf"])
         logger.info("Tensorflow: setting up strategy")
 
         gpus = tf.config.list_physical_devices("GPU")
@@ -234,19 +234,19 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
         return strategy
 
     @property
-    @tf_required
     def strategy(self) -> "tf.distribute.Strategy":
         """
         The strategy used for distributed training.
         """
+        requires_backends(self, ["tf"])
         return self._setup_strategy
 
     @property
-    @tf_required
     def n_replicas(self) -> int:
         """
         The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
+        requires_backends(self, ["tf"])
         return self._setup_strategy.num_replicas_in_sync
 
     @property
@@ -276,11 +276,11 @@ def eval_batch_size(self) -> int:
         return per_device_batch_size * self.n_replicas
 
     @property
-    @tf_required
     def n_gpu(self) -> int:
         """
         The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
+        requires_backends(self, ["tf"])
         warnings.warn(
             "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
             FutureWarning,
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 525149417f24..353fe45e8e41 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -163,9 +163,7 @@
     is_training_run_on_sagemaker,
     is_vision_available,
     requires_backends,
-    tf_required,
     torch_only_method,
-    torch_required,
     torch_version,
 )
 
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index a76b730dc0ea..d09269745297 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -22,7 +22,7 @@
 import sys
 import warnings
 from collections import OrderedDict
-from functools import lru_cache, wraps
+from functools import lru_cache
 from itertools import chain
 from types import ModuleType
 from typing import Any
@@ -1039,30 +1039,6 @@ def __getattribute__(cls, key):
         requires_backends(cls, cls._backends)
 
 
-def torch_required(func):
-    # Chose a different decorator name than in tests so it's clear they are not the same.
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if is_torch_available():
-            return func(*args, **kwargs)
-        else:
-            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
-
-    return wrapper
-
-
-def tf_required(func):
-    # Chose a different decorator name than in tests so it's clear they are not the same.
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if is_tf_available():
-            return func(*args, **kwargs)
-        else:
-            raise ImportError(f"Method `{func.__name__}` requires TF.")
-
-    return wrapper
-
-
 def is_torch_fx_proxy(x):
     if is_torch_fx_available():
         import torch.fx