diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a4de8abed03d..4703c415e42f 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -52,7 +52,6 @@
     find_pruneable_heads_and_indices,
     id_tensor_storage,
     is_torch_greater_or_equal_than_1_13,
-    is_torch_greater_or_equal_than_2_4,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -90,6 +89,7 @@
     is_peft_available,
     is_remote_url,
     is_safetensors_available,
+    is_torch_greater_or_equal,
     is_torch_sdpa_available,
     is_torch_xla_available,
     logging,
@@ -5032,7 +5032,7 @@ def tensor_parallel(self, device_mesh):
             device_mesh (`torch.distributed.DeviceMesh`):
                 The device mesh to use for tensor parallelism.
         """
-        if not is_torch_greater_or_equal_than_2_4:
+        if not is_torch_greater_or_equal("2.5"):
             raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
 
         # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 6757f72350ba..5bdf8a355ddf 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -21,7 +21,7 @@
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
 
-from .utils import is_torch_xla_available, logging
+from .utils import is_torch_greater_or_equal, is_torch_xla_available, logging
 
 
 ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@@ -39,7 +39,7 @@
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
 
 
-if is_torch_greater_or_equal_than_2_4:
+if is_torch_greater_or_equal("2.5"):
     from torch.distributed.tensor import Replicate
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,