10 changes: 5 additions & 5 deletions src/transformers/benchmark/benchmark_args.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments


@@ -76,8 +76,8 @@ def __init__(self, **kwargs):
)

@cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]:
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if not self.cuda:
device = torch.device("cpu")
@@ -95,19 +95,19 @@ def is_tpu(self):
return is_torch_tpu_available() and self.tpu

@property
@torch_required
def device_idx(self) -> int:
requires_backends(self, ["torch"])
# TODO(PVP): currently only single GPU is supported
return torch.cuda.current_device()

@property
@torch_required
def device(self) -> "torch.device":
requires_backends(self, ["torch"])
return self._setup_devices[0]

@property
@torch_required
def n_gpu(self):
requires_backends(self, ["torch"])
return self._setup_devices[1]

@property
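The change above is representative of the whole PR: the backend guard moves from a `@torch_required` / `@tf_required` decorator to an explicit `requires_backends(...)` call at the top of the method body. The sketch below illustrates the pattern with a simplified stand-in for `requires_backends` (the real helper lives in `transformers.utils` and builds a more detailed error message); the class, the inline `import torch`, and the helper shown here are illustrative only, not the library's exact code.

```python
# Illustrative sketch of the migration pattern, not the library's exact code.
import importlib.util


def requires_backends(obj, backends):
    # Simplified stand-in for transformers.utils.requires_backends.
    module_names = {"torch": "torch", "tf": "tensorflow"}
    name = getattr(obj, "__name__", obj.__class__.__name__)
    for backend in backends:
        if importlib.util.find_spec(module_names[backend]) is None:
            raise ImportError(f"`{name}` requires the `{backend}` backend, which is not installed.")


class BenchmarkArgsSketch:
    # Before this PR the guard was stacked as a decorator:
    #
    #     @property
    #     @torch_required
    #     def device_idx(self) -> int:
    #         return torch.cuda.current_device()
    #
    # After this PR the guard is the first statement of the body, so the
    # decorator (and its import) can be dropped.
    @property
    def device_idx(self) -> int:
        requires_backends(self, ["torch"])
        import torch

        return torch.cuda.current_device()
```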
14 changes: 7 additions & 7 deletions src/transformers/benchmark/benchmark_args_tf.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..utils import cached_property, is_tf_available, logging, tf_required
from ..utils import cached_property, is_tf_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments


@@ -77,8 +77,8 @@ def __init__(self, **kwargs):
)

@cached_property
@tf_required
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
tpu = None
if self.tpu:
try:
@@ -91,8 +91,8 @@ def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
return tpu

@cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
if self.is_tpu:
tf.config.experimental_connect_to_cluster(self._setup_tpu)
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
@@ -111,23 +111,23 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
return strategy

@property
@tf_required
def is_tpu(self) -> bool:
requires_backends(self, ["tf"])
return self._setup_tpu is not None

@property
@tf_required
def strategy(self) -> "tf.distribute.Strategy":
requires_backends(self, ["tf"])
return self._setup_strategy

@property
@tf_required
def gpu_list(self):
requires_backends(self, ["tf"])
return tf.config.list_physical_devices("GPU")

@property
@tf_required
def n_gpu(self) -> int:
requires_backends(self, ["tf"])
if self.cuda:
return len(self.gpu_list)
return 0
4 changes: 2 additions & 2 deletions src/transformers/feature_extraction_utils.py
@@ -42,7 +42,7 @@
is_torch_device,
is_torch_dtype,
logging,
torch_required,
requires_backends,
)


@@ -175,7 +175,6 @@ def as_tensor(value):

return self

@torch_required
def to(self, *args, **kwargs) -> "BatchFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
@@ -190,6 +189,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
Returns:
[`BatchFeature`]: The same instance after modification.
"""
requires_backends(self, ["torch"])
import torch # noqa

new_data = {}
2 changes: 0 additions & 2 deletions src/transformers/file_utils.py
@@ -127,10 +127,8 @@
is_vision_available,
replace_return_docstrings,
requires_backends,
tf_required,
to_numpy,
to_py_obj,
torch_only_method,
torch_required,
torch_version,
)
4 changes: 2 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -56,8 +56,8 @@
is_torch_device,
is_torch_tensor,
logging,
requires_backends,
to_py_obj,
torch_required,
)


@@ -739,7 +739,6 @@ def convert_to_tensors(

return self

@torch_required
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
"""
Send all values to device by calling `v.to(device)` (PyTorch only).
@@ -750,6 +749,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
Returns:
[`BatchEncoding`]: The same instance after modification.
"""
requires_backends(self, ["torch"])

# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
16 changes: 8 additions & 8 deletions src/transformers/training_args.py
@@ -50,7 +50,6 @@
is_torch_tpu_available,
logging,
requires_backends,
torch_required,
)


@@ -1386,8 +1385,8 @@ def ddp_timeout_delta(self) -> timedelta:
return timedelta(seconds=self.ddp_timeout)

@cached_property
@torch_required
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
logger.warning(
@@ -1537,15 +1536,14 @@ def _setup_devices(self) -> "torch.device":
return device

@property
@torch_required
def device(self) -> "torch.device":
"""
The device used by this process.
"""
requires_backends(self, ["torch"])
return self._setup_devices

@property
@torch_required
def n_gpu(self):
"""
The number of GPUs used by this process.
@@ -1554,12 +1552,12 @@ def n_gpu(self):
This will only be greater than one when you have multiple GPUs available but are not using distributed
training. For distributed training, it will always be 1.
"""
requires_backends(self, ["torch"])
# Make sure `self._n_gpu` is properly setup.
_ = self._setup_devices
return self._n_gpu

@property
@torch_required
def parallel_mode(self):
"""
The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
@@ -1570,6 +1568,7 @@ def parallel_mode(self):
`torch.nn.DistributedDataParallel`).
- `ParallelMode.TPU`: several TPU cores.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return ParallelMode.TPU
elif is_sagemaker_mp_enabled():
@@ -1584,11 +1583,12 @@ def world_size(self):
return ParallelMode.NOT_PARALLEL

@property
@torch_required
def world_size(self):
"""
The number of processes used in parallel.
"""
requires_backends(self, ["torch"])

if is_torch_tpu_available():
return xm.xrt_world_size()
elif is_sagemaker_mp_enabled():
@@ -1600,11 +1600,11 @@ def world_size(self):
return 1

@property
@torch_required
def process_index(self):
"""
The index of the current process used.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.get_ordinal()
elif is_sagemaker_mp_enabled():
@@ -1616,11 +1616,11 @@ def process_index(self):
return 0

@property
@torch_required
def local_process_index(self):
"""
The index of the local process used.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.get_local_ordinal()
elif is_sagemaker_mp_enabled():
10 changes: 5 additions & 5 deletions src/transformers/training_args_tf.py
@@ -17,7 +17,7 @@
from typing import Optional, Tuple

from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required
from .utils import cached_property, is_tf_available, logging, requires_backends


logger = logging.get_logger(__name__)
@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})

@cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
requires_backends(self, ["tf"])
logger.info("Tensorflow: setting up strategy")

gpus = tf.config.list_physical_devices("GPU")
@@ -234,19 +234,19 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
return strategy

@property
@tf_required
def strategy(self) -> "tf.distribute.Strategy":
"""
The strategy used for distributed training.
"""
requires_backends(self, ["tf"])
return self._setup_strategy

@property
@tf_required
def n_replicas(self) -> int:
"""
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
"""
requires_backends(self, ["tf"])
return self._setup_strategy.num_replicas_in_sync

@property
@@ -276,11 +276,11 @@ def eval_batch_size(self) -> int:
return per_device_batch_size * self.n_replicas

@property
@tf_required
def n_gpu(self) -> int:
"""
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
"""
requires_backends(self, ["tf"])
warnings.warn(
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
FutureWarning,
2 changes: 0 additions & 2 deletions src/transformers/utils/__init__.py
@@ -163,9 +163,7 @@
is_training_run_on_sagemaker,
is_vision_available,
requires_backends,
tf_required,
torch_only_method,
torch_required,
torch_version,
)

26 changes: 1 addition & 25 deletions src/transformers/utils/import_utils.py
@@ -22,7 +22,7 @@
import sys
import warnings
from collections import OrderedDict
from functools import lru_cache, wraps
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any
@@ -1039,30 +1039,6 @@ def __getattribute__(cls, key):
requires_backends(cls, cls._backends)


def torch_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")

return wrapper


def tf_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires TF.")

return wrapper


def is_torch_fx_proxy(x):
if is_torch_fx_available():
import torch.fx
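With `torch_required` and `tf_required` removed, `requires_backends` is the single entry point for backend checks; as the `__getattribute__` hunk above shows, it is the same helper the dummy-object metaclass relies on. Below is a hedged, self-contained sketch of that class-level guard, using a simplified helper and hypothetical class names rather than the library's exact implementation.

```python
# Illustrative sketch of the class-level backend guard; simplified, not the real code.
import importlib.util


def requires_backends(obj, backends):
    # Simplified stand-in for transformers.utils.requires_backends.
    module_names = {"torch": "torch", "tf": "tensorflow"}
    name = getattr(obj, "__name__", obj.__class__.__name__)
    missing = [b for b in backends if importlib.util.find_spec(module_names[b]) is None]
    if missing:
        raise ImportError(f"`{name}` requires the following backends: {', '.join(missing)}")


class DummyObject(type):
    # Metaclass for placeholder classes: any public attribute access on the
    # class runs the backend check, so a missing backend fails with a clear
    # ImportError instead of an unrelated error later on.
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)
        return super().__getattribute__(key)


class TorchOnlyPlaceholder(metaclass=DummyObject):
    _backends = ["torch"]
    note = "only reachable when torch is installed"


# Accessing TorchOnlyPlaceholder.note raises ImportError when torch is absent,
# and returns the string when torch is available.
```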