diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 6757f72350ba..470054a80fa9 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) +is_torch_greater_or_equal_than_2_5 = parsed_torch_version_base >= version.parse("2.5") is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4") is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3") is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") @@ -39,7 +40,7 @@ is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") -if is_torch_greater_or_equal_than_2_4: +if is_torch_greater_or_equal_than_2_5: from torch.distributed.tensor import Replicate from torch.distributed.tensor.parallel import ( ColwiseParallel,