diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 3e2d454e6e63..30dad84b16d1 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -16,9 +16,13 @@ from bisect import bisect_left import torch -from torch._six import inf from deepspeed import comm as dist +try: + from torch._six import inf as inf +except ModuleNotFoundError: + from torch import inf as inf + from deepspeed.utils import groups, logger from deepspeed.runtime.constants import PIPE_REPLICATED from numpy import prod diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d56898efa20f..084f29c27922 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,13 +7,12 @@ import gc import collections from typing import Deque, Dict, Tuple -from torch._six import inf from deepspeed.runtime import ZeROOptimizer from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced -from deepspeed.runtime.utils import get_global_norm, is_model_parallel_parameter +from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 0b9d9c5a6fef..8c980d6a75c1 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -5,7 +5,6 @@ import torch import os from deepspeed import comm as dist -from torch._six import inf from packaging import version as pkg_version from collections import OrderedDict @@ -15,6 +14,7 @@ get_global_norm, empty_cache, see_memory_usage, + inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)