6 changes: 5 additions & 1 deletion deepspeed/runtime/utils.py
@@ -16,9 +16,13 @@
 from bisect import bisect_left
 
 import torch
-from torch._six import inf
 from deepspeed import comm as dist
 
+try:
+    from torch._six import inf as inf
+except ModuleNotFoundError:
+    from torch import inf as inf
+
 from deepspeed.utils import groups, logger
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from numpy import prod
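The utils.py change is a version-compatibility shim: torch._six was a private module that newer PyTorch releases removed, so the import now falls back to the public torch.inf when the old path is unavailable. A minimal sketch of the same pattern follows; the norm helpers are hypothetical and only illustrate how the re-exported inf is typically used, they are not DeepSpeed code.

try:
    from torch._six import inf  # available in older PyTorch releases
except ModuleNotFoundError:
    from torch import inf  # torch._six was removed; torch.inf is the public replacement

import torch


def global_max_norm(tensors):
    # Illustrative helper: `inf` usually serves as the sentinel for the
    # max-norm, mirroring torch.nn.utils.clip_grad_norm_.
    return max(t.detach().abs().max().item() for t in tensors)


def global_norm(tensors, norm_type=2.0):
    if norm_type == inf:
        return global_max_norm(tensors)
    total = sum(t.detach().norm(norm_type).item()**norm_type for t in tensors)
    return total**(1.0 / norm_type)

Because utils.py now re-exports inf, the other two files in this PR import it from deepspeed.runtime.utils rather than torch._six, keeping the PyTorch version handling in one place.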
3 changes: 1 addition & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -7,13 +7,12 @@
 import gc
 import collections
 from typing import Deque, Dict, Tuple
-from torch._six import inf
 
 from deepspeed.runtime import ZeROOptimizer
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
-from deepspeed.runtime.utils import get_global_norm, is_model_parallel_parameter
+from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -5,7 +5,6 @@
 import torch
 import os
 from deepspeed import comm as dist
-from torch._six import inf
 from packaging import version as pkg_version
 from collections import OrderedDict
 
@@ -15,6 +14,7 @@
     get_global_norm,
     empty_cache,
     see_memory_usage,
+    inf,
     is_model_parallel_parameter,
     align_dense_tensors,
     all_gather_dp_groups)