Skip to content

Commit 64aea56

Browse files
committed
Avoid importing apex transformer automatically and make error messages more clear when apex.transformer is explicitly called on unsupported platform
1 parent a7de60e commit 64aea56

File tree

2 files changed

+1
-1
lines changed

2 files changed

+1
-1
lines changed

apex/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
# load time) the error message is timely and visible.
2525
from . import optimizers
2626
from . import normalization
27-
from . import transformer
2827

2928

3029
# Logging utilities for apex.transformer module

apex/transformer/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# The following 4 lines are for backward comparability with
99
# older PyTorch.
1010
if "all_gather_into_tensor" not in dir(torch.distributed):
11+
assert torch.distributed.is_available(), "PyTorch Distributed is Not available or Disabled."
1112
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
1213

1314
def ensure_divisibility(numerator, denominator):

0 commit comments

Comments
 (0)