Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 1 addition & 28 deletions deepspeed/inference/config.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,12 @@
import torch
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, DtypeEnum
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
from pydantic import validator
from typing import Dict, Union
from enum import Enum


class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
int8 = torch.int8, "torch.int8", "int8"

# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
def __new__(cls, *values):
obj = object.__new__(cls)
# first value is canonical value
obj._value_ = values[0]
for other_value in values[1:]:
cls._value2member_map_[other_value] = obj
obj._all_values = values
return obj

def __repr__(self):
return "<%s.%s: %s>" % (
self.__class__.__name__,
self._name_,
", ".join([repr(v) for v in self._all_values]),
)


class MoETypeEnum(str, Enum):
residual = "residual"
standard = "standard"
Expand Down
Loading