Merged
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
@@ -52,6 +52,7 @@
find_pruneable_heads_and_indices,
id_tensor_storage,
is_torch_greater_or_equal_than_1_13,
is_torch_greater_or_equal_than_2_4,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
@@ -5005,6 +5006,8 @@ def tensor_parallel(self, device_mesh):
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
if not is_torch_greater_or_equal_than_2_4:
Contributor:
Hi @ArthurZucker - this mismatch between the torch 2.4 check and the 2.5 requirement in the error message means that torch 2.4 still hits this issue (whereas torch 2.3 now works properly).

raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
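The guard above hinges on module-level version flags like `is_torch_greater_or_equal_than_2_4`. As a minimal sketch of how such flags are derived (using `packaging`, which `pytorch_utils.py` already imports; `is_at_least` and `check_tensor_parallel_support` are hypothetical names, not transformers API), the comparison the reviewer asks for — gating on 2.5 rather than 2.4 — could look like:

```python
from packaging import version


def is_at_least(current: str, minimum: str) -> bool:
    # Compare only the base version so local/dev suffixes
    # (e.g. "2.4.0+cu121", "2.5.0.dev20240901") don't skew the check.
    base = version.parse(version.parse(current).base_version)
    return base >= version.parse(minimum)


def check_tensor_parallel_support(torch_version: str) -> None:
    # Mirrors the guard in tensor_parallel, but rejects anything below 2.5
    # so the check matches the version named in the error message.
    if not is_at_least(torch_version, "2.5"):
        raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
```

With this shape, the check and the error message cannot drift apart: both reference the same minimum version string.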
13 changes: 8 additions & 5 deletions src/transformers/pytorch_utils.py
@@ -20,11 +20,6 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)

from .utils import is_torch_xla_available, logging

@@ -44,6 +39,14 @@
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")


if is_torch_greater_or_equal_than_2_4:
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)
Comment on lines +42 to +47

Contributor:
Actually, Replicate lives in torch.distributed.tensor only on torch >= 2.5.
From 2.0 to 2.4 (inclusive), it is in torch.distributed._tensor.

Thus there seem to be two options:

Option 1:

try:
    from torch.distributed.tensor import Replicate  # torch >= 2.5
except ImportError:
    from torch.distributed._tensor import Replicate  # torch 2.0 - 2.4

Option 2: bump the requirement to torch >= 2.5.

ColwiseParallel and RowwiseParallel have existed since 2.0.

Contributor:
Hi @kwen2501 - no preference from my side. I have a PR that implements the first option here: #34816

But I'm happy to abandon it or switch in favor of anything you have planned to add support.
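The try/except import in Option 1 generalizes to a small helper when several symbols move between module paths across versions. This is an illustrative sketch, not transformers code — `import_first_available` is a hypothetical name, and the commented usage line only records the two locations named in the review:

```python
import importlib


def import_first_available(*module_paths):
    """Return the first module in `module_paths` that imports successfully."""
    for path in module_paths:
        try:
            return importlib.import_module(path)
        except ImportError:
            continue
    raise ImportError(f"none of {module_paths} could be imported")


# For Replicate, per the review: try the new public location first,
# then fall back to the private one on older torch versions.
# dist_tensor = import_first_available(
#     "torch.distributed.tensor",   # torch >= 2.5
#     "torch.distributed._tensor",  # torch 2.0 - 2.4
# )
```

Attempting the public path first means the private fallback is only touched on versions where it is the sole option, which keeps the code forward-compatible if the private module is eventually removed.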



def softmax_backward_data(parent, grad_output, output, dim, self):
"""
A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according