diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index a45d0c256a76..7900376269b7 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -15,9 +15,13 @@
 from bisect import bisect_left, bisect_right
 
 import torch
-from torch._six import inf
 import torch.distributed as dist
 
+try:
+    from torch._six import inf as inf
+except ModuleNotFoundError:
+    from torch import inf as inf
+
 from deepspeed.utils import groups, logger
 from numpy import prod
 
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 42a85a755629..fab837b477e6 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -6,14 +6,13 @@
 from torch.distributed.distributed_c10d import _get_global_rank
 import torch.distributed as dist
 import math
-from torch._six import inf
 from torch.autograd import Variable
 from packaging import version as pkg_version
 
 import collections
 
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter
+from deepspeed.runtime.utils import inf, bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER, OFFLOAD_OPTIMIZER_DEVICE
 from deepspeed.ops.adam import DeepSpeedCPUAdam
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 1387ced6fc8d..84d1147f0a32 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -11,12 +11,11 @@
 from torch.distributed.distributed_c10d import _get_global_rank
 import torch.distributed as dist
 import math
-from torch._six import inf
 from torch.autograd import Variable
 
 from deepspeed.utils.logging import logger
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
+from deepspeed.runtime.utils import inf, get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.partition_parameters import _init_external_params
 from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
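For reference, the patch follows the usual pattern for surviving the removal of the private `torch._six` module: try the legacy import first, fall back to the public `torch.inf` alias, and have downstream modules (`stage2.py`, `stage3.py`) import the re-exported name from `deepspeed.runtime.utils` instead of reaching into `torch._six` themselves. A minimal standalone sketch of the same shim (the consumer function is illustrative, not part of the patch):

```python
# Compatibility shim, mirroring the change in deepspeed/runtime/utils.py.
try:
    # Older PyTorch releases expose ``inf`` via the private ``torch._six`` module.
    from torch._six import inf
except ModuleNotFoundError:
    # Newer releases removed ``torch._six``; ``inf`` is available directly on ``torch``.
    from torch import inf


def is_overflow(total_norm: float) -> bool:
    """Example consumer: flag a gradient norm that is infinite or NaN."""
    return total_norm == inf or total_norm != total_norm
```

Note that catching `ModuleNotFoundError`, as the patch does, only covers the missing-module case; since `ModuleNotFoundError` is a subclass of `ImportError`, a broader `except ImportError` would additionally handle a `torch._six` that exists but no longer exports `inf`.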