diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py index 74c55279645..6fe93790154 100644 --- a/flash_attn/utils/distributed.py +++ b/flash_attn/utils/distributed.py @@ -9,9 +9,11 @@ # version of PyTorch. The following 4 lines are for backward compatibility with # older PyTorch. if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + if hasattr(torch.distributed, "_all_gather_base"): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + if hasattr(torch.distributed, "_reduce_scatter_base"): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base # Raw operation, does not support autograd, but does support async diff --git a/third_party/aiter b/third_party/aiter index 428e8e761c7..b4b75165fbd 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit 428e8e761c7bc22d03513bcb8507375afef1f916 +Subproject commit b4b75165fbd2456dfd0f074c5b2ef91bc87d97e5