From 67e9d636f2481878d827e1b1528f3c0517962d6d Mon Sep 17 00:00:00 2001 From: rraminen Date: Wed, 3 May 2023 18:13:12 -0400 Subject: [PATCH] Remove deprecated torch._six imports --- deepspeed/runtime/utils.py | 6 +++++- deepspeed/runtime/zero/stage3.py | 3 +-- deepspeed/runtime/zero/stage_1_and_2.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index ff6daba1c6c9..b06015e6c4ef 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -16,9 +16,13 @@ from bisect import bisect_left, bisect_right import torch -from torch._six import inf import torch.distributed as dist +try: + from torch._six import inf as inf +except ModuleNotFoundError: + from torch import inf as inf + from deepspeed.utils import groups, logger from deepspeed.runtime.constants import PIPE_REPLICATED from numpy import prod diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e963ef643677..a548989729b4 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,7 +16,6 @@ from torch.nn import Module, Parameter import torch.distributed as dist import math -from torch._six import inf from torch.nn import Module from torch.nn.parameter import Parameter @@ -24,7 +23,7 @@ from deepspeed.utils.logging import logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced -from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim +from deepspeed.runtime.utils import inf, get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6e3fd3e0612b..3061c5ba39e1 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -5,12 +5,12 @@ import torch from torch.distributed.distributed_c10d import _get_global_rank import torch.distributed as dist -from torch._six import inf from packaging import version as pkg_version from deepspeed.runtime import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, +from deepspeed.runtime.utils import (inf, + bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter,