[Kernel][Misc] dynamo support for ScalarType (vllm-project#7594)
bnellnm authored Aug 16, 2024
1 parent 9f69856 commit 7759ae9
Showing 2 changed files with 149 additions and 24 deletions.
30 changes: 30 additions & 0 deletions csrc/core/scalar_type.hpp
@@ -313,6 +313,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}

// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t len() const {
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
"__len__ not implemented");
return 0;
}

// Serialize a ScalarType into a tuple of (fieldname, value) pairs.
// For simplicity, we just convert to a ScalarTypeId.
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
return {{"ScalarType", id()}};
}

// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static SelfPtr obj_unflatten(
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
return c10::make_intrusive<Self>(
from_id(std::get<1>(std::get<0>(flat_type))));
}

template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self.get()->min());
});

bind_function(cls, "__len__", &ScalarTypeTorch::len);
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return "ScalarType." + self.get()->str();
});

bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
bind_static_function(cls, "__obj_unflatten__",
&ScalarTypeTorch::obj_unflatten);

// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
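The new __len__, __obj_flatten__, and __obj_unflatten__ bindings are what let PyTorch treat a ScalarType as serializable data during tracing rather than as an opaque custom-class pointer. A minimal sketch of the round trip as seen from Python, assuming vllm was built with the _core_C extension (the gptq_uint4 name is illustrative):

import torch
from vllm._core_ext import ScalarType

# GPTQ-style 4-bit unsigned integer type with a bias of 8.
gptq_uint4 = ScalarType.uint(4, 8)

# __obj_flatten__ reduces the object to (("ScalarType", <id>),) ...
flat = gptq_uint4.__obj_flatten__()

# ... and __obj_unflatten__ rebuilds an equivalent instance from that id.
restored = torch.classes._core_C.ScalarType.__obj_unflatten__(flat)
assert str(restored) == str(gptq_uint4)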
143 changes: 119 additions & 24 deletions vllm/_core_ext.py
@@ -1,6 +1,6 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import torch

@@ -31,14 +31,14 @@ class NanRepr(Enum):
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.
`stored_value = value + bias`, which is useful for quantized types
(e.g. standard GPTQ 4-bit uses a bias of 8). The implementation of this
class can be found in csrc/core/scalar_type.hpp; these type signatures
should be kept in sync with that file.
"""

@@ -51,15 +51,15 @@ class ScalarType:
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number of bits representing an integer, excluding the sign bit,
if this is an integer type.
"""

bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0); for example, if we store the
type as an unsigned integer with a bias of 128, then the value 0 will be
stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
"""

@@ -73,7 +73,7 @@ class ScalarType:

nan_repr: int = NanRepr.IEEE_754.value
"""
How NaNs are represented in this scalar type, as a NanRepr value.
(not applicable for integer types)
"""

Expand All @@ -83,14 +83,14 @@ def size_bits(self):

def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError

def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
@@ -103,28 +103,28 @@ def is_signed(self) -> bool:
"""
...

def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0

def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0

def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0

def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only

def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE.value

def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754.value and \
@@ -136,6 +136,11 @@ def __str__(self) -> str:
def __repr__(self) -> str:
raise NotImplementedError

# __len__ needs to be defined (and has to throw TypeError) for PyTorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError

#
# Convenience Constructors
#
@@ -153,16 +158,16 @@ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True)

@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int) -> 'ScalarType':
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True, finite_values_only,
@@ -175,3 +180,93 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
logger.warning("Failed to import from vllm._core_C with %r", e)

ScalarType = torch.classes._core_C.ScalarType
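# From this point on, ScalarType refers to the C++ custom class registered
# by the _core_C extension rather than the pure-Python dataclass above.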

# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
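# A pure-Python stand-in for the C++ custom class: it wraps a real
# ScalarType and forwards every query to it, which lets dynamo and
# fake-tensor tracing see through ops that take a ScalarType argument.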

def __init__(self, scalar_type):
self.ScalarType = scalar_type

def bias_getter(self) -> int:
return self.ScalarType.bias

def exponent_getter(self) -> int:
return self.ScalarType.exponent

def mantissa_getter(self) -> int:
return self.ScalarType.mantissa

def signed_getter(self) -> bool:
return self.ScalarType.signed

def size_bits_getter(self) -> int:
return self.ScalarType.size_bits

@property
def size_bits(self) -> int:
return self.ScalarType.size_bits

def min(self) -> Union[int, float]:
return self.ScalarType.min()

def max(self) -> Union[int, float]:
return self.ScalarType.max()

def is_signed(self) -> bool:
return self.ScalarType.is_signed()

def is_floating_point(self) -> bool:
return self.ScalarType.is_floating_point()

def is_integer(self) -> bool:
return self.ScalarType.is_integer()

def has_bias(self) -> bool:
return self.ScalarType.has_bias()

def has_infs(self) -> bool:
return self.ScalarType.has_infs()

def has_nans(self) -> bool:
return self.ScalarType.has_nans()

def is_ieee_754(self) -> bool:
return self.ScalarType.is_ieee_754()

def __str__(self) -> str:
return self.ScalarType.__str__()

def __repr__(self) -> str:
return self.ScalarType.__repr__()

def __len__(self) -> int:
return self.ScalarType.__len__()

def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
return torch.classes._core_C.ScalarType.__obj_flatten__(
self.ScalarType)

@classmethod
def __obj_unflatten__(
cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
return cls(
torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))

@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.int_(size_bits, bias)

@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.uint(size_bits, bias)

@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
return ScalarType.float_IEEE754(exponent, mantissa)

@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int) -> 'ScalarType':
return ScalarType.float_(exponent, mantissa, finite_values_only,
nan_repr)
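
With the fake class registered, torch.compile can trace functions that use a ScalarType without graph-breaking on the torch.classes object. A hedged sketch of the intended effect (illustrative only; the clamp_to_type example is not from this PR, and how cleanly a given method traces depends on the PyTorch version):

import torch
from vllm._core_ext import ScalarType

gptq_uint4 = ScalarType.uint(4, 8)

@torch.compile(fullgraph=True)
def clamp_to_type(x: torch.Tensor) -> torch.Tensor:
    # During tracing, method calls on the ScalarType are resolved through
    # FakeScalarType instead of causing a graph break.
    return x.clamp(gptq_uint4.min(), gptq_uint4.max())

out = clamp_to_type(torch.randn(8))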
