diff --git a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py
index f38f371b7..b1ec58c9a 100644
--- a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py
+++ b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py
@@ -1,12 +1,27 @@
-import os
 import inspect
+
 import torch
 from torch.cuda.amp import GradScaler
 from torch.testing._internal import common_utils
-from apex.parallel.distributed import flat_dist_call
+from torch.distributed.distributed_c10d import _coalescing_manager
+
 from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB
 from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
 
+
+def flat_dist_call(param_list: list[torch.Tensor], op, args):
+    """Call ``op(p, *args)`` for each tensor ``p`` in ``param_list``.
+
+    The calls are issued inside ``_coalescing_manager(async_ops=True)``;
+    ``cm.wait()`` is invoked after the ``with`` block exits, before returning.
+    """
+    with _coalescing_manager(async_ops=True) as cm:
+        for p in param_list:
+            op(p, *args)
+
+    cm.wait()
+
+
 def get_init_weights_func():
     @torch.no_grad()
     def init_weights(m):