3 changes: 2 additions & 1 deletion docker/install/ubuntu_install_python_package.sh
@@ -44,4 +44,5 @@ pip3 install --upgrade \
junitparser==2.4.2 \
six \
tornado \
pytest-lazy-fixture
pytest-lazy-fixture \
ml_dtypes
34 changes: 34 additions & 0 deletions include/tvm/runtime/data_type.h
@@ -32,6 +32,7 @@

namespace tvm {
namespace runtime {

/*!
* \brief Runtime primitive data type.
*
@@ -54,6 +55,8 @@ class DataType {
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
kCustomBegin = 129
};
/*! \brief default constructor */
@@ -76,6 +79,9 @@
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
if (code == kE4M3Float || code == kE5M2Float) {
ICHECK_EQ(bits, 8);
}
}
/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
@@ -91,6 +97,12 @@
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
code() == DataType::kE5M2Float) &&
bits() == 8;
}
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
@@ -183,6 +195,18 @@
* \return The constructed data type.
*/
static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
/*!
* \brief Construct NV float8 e4m3 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
/*!
* \brief Construct NV float8 e5m2 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
@@ -308,6 +332,10 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "handle";
case kDLBfloat:
return "bfloat";
case DataType::kE4M3Float:
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
@@ -376,6 +404,12 @@ inline DLDataType String2DLDataType(std::string s) {
} else if (s.substr(0, 6) == "bfloat") {
t.code = DataType::kBFloat;
scan = s.c_str() + 6;
} else if (s.substr(0, 10) == "e4m3_float") {
t.code = DataType::kE4M3Float;
scan = s.c_str() + 10;
} else if (s.substr(0, 10) == "e5m2_float") {
t.code = DataType::kE5M2Float;
scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
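A quick sketch of the new canonical dtype strings, exercised through the Python-side wrapper that this diff also updates (a usage sketch; behavior is inferred from the parsing code above, not spelled out in the PR):

    from tvm._ffi.runtime_ctypes import DataType

    # the type-code name is "e4m3_float"/"e5m2_float"; the trailing digits
    # are the bit width, and an optional "xN" suffix encodes vector lanes
    scalar = DataType("e4m3_float8")
    vec4 = DataType("e5m2_float8x4")
    assert scalar.bits == 8 and scalar.lanes == 1
    assert vec4.lanes == 4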
3 changes: 2 additions & 1 deletion include/tvm/tir/op.h
@@ -939,7 +939,8 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span);
if (t.is_float() || t.is_bfloat16() || t.is_float8())
return FloatImm(t, static_cast<double>(value), span);
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
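Since MakeConstScalar now routes float8 through FloatImm, an fp8 literal can be built from Python like any other float constant (a sketch; it assumes the FloatImm changes in src/ir/expr.cc further down this diff):

    import tvm

    c = tvm.tir.FloatImm("e4m3_float8", 1.5)  # stored as a double until lowering
    assert c.dtype == "e4m3_float8"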
14 changes: 14 additions & 0 deletions include/tvm/tir/transform.h
@@ -356,12 +356,26 @@ TVM_DLL Pass NarrowDataType(int target_bits);
*/
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
* \param promote_dtype_str The data type used for type promotion, defaults to float16
* \return The pass.
*/
TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");

/*!
* \brief Legalize bf16 storage types to u16.
* \return The pass.
*/
TVM_DLL Pass BF16StorageLegalize();

/*!
* \brief Legalize fp8 storage types to u8.
* \return The pass.
*/
TVM_DLL Pass FP8StorageLegalize();

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
1 change: 1 addition & 0 deletions python/gen_requirements.py
@@ -67,6 +67,7 @@
"attrs",
"cloudpickle",
"decorator",
"ml_dtypes",
"numpy",
"psutil",
"scipy",
22 changes: 22 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -19,6 +19,11 @@
import ctypes
import json
import numpy as np

try:
import ml_dtypes
except ImportError:
ml_dtypes = None
from .base import _LIB, check_call

tvm_shape_index_t = ctypes.c_int64
@@ -59,6 +64,8 @@ class DataTypeCode(object):
FLOAT = 2
HANDLE = 3
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7


class DataType(ctypes.Structure):
@@ -71,6 +78,8 @@
DataTypeCode.FLOAT: "float",
DataTypeCode.HANDLE: "handle",
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
@@ -97,6 +106,8 @@
"uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
"uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -141,6 +152,12 @@ def __init__(self, type_str):
elif head.startswith("bfloat"):
self.type_code = DataTypeCode.BFLOAT
head = head[6:]
elif head.startswith("e4m3_float"):
self.type_code = DataTypeCode.E4M3Float
head = head[10:]
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
@@ -182,6 +199,11 @@ def __ne__(self, other):
return not self.__eq__(other)


if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"

RPC_SESS_MASK = 128


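With ml_dtypes present, the NUMPY2STR additions let fp8 numpy dtypes map onto TVM dtype strings; a small cross-check (a sketch, assuming ml_dtypes is installed):

    import numpy as np
    import ml_dtypes
    from tvm._ffi.runtime_ctypes import DataType, DataTypeCode

    assert DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] == "e4m3_float8"
    assert DataType("e5m2_float8").type_code == DataTypeCode.E5M2Float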
17 changes: 17 additions & 0 deletions python/tvm/contrib/nvcc.py
@@ -404,3 +404,20 @@ def have_bf16(compute_version):
return True

return False


def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not

Parameters
----------
compute_version : str
GPU capability
"""
major, minor = parse_compute_version(compute_version)
# fp8 is supported in Ada Lovelace (8.9) or later architectures.
if major == 8 and minor == 9:
return True
if major >= 9:
return True
return False
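For example (compute-capability strings in the same "major.minor" form that parse_compute_version already accepts):

    from tvm.contrib import nvcc

    assert nvcc.have_fp8("8.9")      # Ada Lovelace
    assert nvcc.have_fp8("9.0")      # Hopper
    assert not nvcc.have_fp8("8.6")  # pre-Ada Ampere has no fp8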
19 changes: 19 additions & 0 deletions python/tvm/runtime/ndarray.py
@@ -19,6 +19,11 @@
import ctypes
import warnings
import numpy as np

try:
import ml_dtypes
except ImportError:
ml_dtypes = None
import tvm._ffi

from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
@@ -220,6 +225,20 @@ def numpy(self):
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if dtype == "e4m3_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
)
if dtype == "e5m2_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
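Round-tripping an fp8 array through the runtime and back to numpy (a sketch assuming ml_dtypes is installed and that tvm.nd.array picks the dtype string up via the NUMPY2STR mapping added earlier in this diff):

    import numpy as np
    import ml_dtypes
    import tvm

    src = np.array([0.5, 1.5, 448.0], dtype=ml_dtypes.float8_e4m3fn)
    arr = tvm.nd.array(src)   # carried as "e4m3_float8"
    out = arr.numpy()         # converted back via ml_dtypes.float8_e4m3fn
    assert out.dtype == np.dtype(ml_dtypes.float8_e4m3fn)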
28 changes: 28 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -40,6 +40,7 @@ def Apply(ftransform):
fpass : tvm.transform.Pass
The result pass
"""

# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
@@ -297,6 +298,22 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore


def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
"""Legalize fp8 compute Ops.

Parameters
----------
promote_dtype_str : str
The data type we promote fp8 to, options: float16/float32.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore


def BF16StorageLegalize():
"""Legalize bf16 storage types to u16.

@@ -308,6 +325,17 @@ def BF16StorageLegalize():
return _ffi_api.BF16StorageLegalize() # type: ignore


def FP8StorageLegalize():
"""Legalize fp8 storage types to u8.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8StorageLegalize() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
"""Replace redundant computations by new variables.

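A minimal sketch of applying the two new passes by hand; mod is assumed to be an existing tvm.IRModule containing fp8 compute (the driver changes below wire the same passes into the default pipeline automatically):

    import tvm

    # mod: an existing tvm.IRModule with fp8 arithmetic (assumed)
    mod = tvm.tir.transform.FP8ComputeLegalize("float16")(mod)  # compute in fp16
    mod = tvm.tir.transform.FP8StorageLegalize()(mod)           # store fp8 as u8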
2 changes: 2 additions & 0 deletions src/driver/driver_api.cc
@@ -210,6 +210,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::FP8ComputeLegalize());
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
@@ -586,6 +587,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

14 changes: 13 additions & 1 deletion src/ir/expr.cc
@@ -104,7 +104,8 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.code() >= DataType::kCustomBegin)
ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() ||
dtype.code() >= DataType::kCustomBegin)
<< "ValueError: FloatImm supports only float, but " << dtype << " was supplied.";

// check range for float32 and float16 since they have specified range.
@@ -119,6 +120,17 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) {
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
} else if (dtype.is_bfloat16()) {
ICHECK_GE(value, -support::kMaxBFloat16)
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, support::kMaxBFloat16)
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
} else if (dtype.is_float8()) {
double bound = (dtype.code() == DataType::kE4M3Float) ? support::kMaxE4M3 : support::kMaxE5M2;
ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " exceeds minimum of "
<< dtype;
ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of "
<< dtype;
}
}
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
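The effect of the new bounds, seen from Python (a sketch; the limits come from src/support/scalars.h below, and the checks are inclusive, so the maxima themselves pass):

    import tvm

    tvm.tir.FloatImm("e4m3_float8", 448.0)    # ok: e4m3 max
    tvm.tir.FloatImm("e5m2_float8", 57344.0)  # ok: e5m2 max
    # tvm.tir.FloatImm("e4m3_float8", 512.0)  # would trip the ICHECK above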
12 changes: 12 additions & 0 deletions src/support/scalars.h
@@ -65,6 +65,18 @@ FloatImm ValueToFloatImm(double value, int width);
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

// 2^127 * (1 + 127/128)
// See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
constexpr double kMaxBFloat16 = 3.3895313892515354759047080037148786688e38;

// 2^8 * (1 + 6/8)
// See https://arxiv.org/pdf/2209.05433.pdf
constexpr double kMaxE4M3 = 448;

// 2^15 * (1 + 3/4)
// See https://arxiv.org/pdf/2209.05433.pdf
constexpr double kMaxE5M2 = 57344;

} // namespace support
} // namespace tvm

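These constants agree with what ml_dtypes reports for the matching numpy types, which makes for an easy cross-check (a sketch, assuming ml_dtypes is installed):

    import numpy as np
    import ml_dtypes

    assert float(np.finfo(ml_dtypes.float8_e4m3fn).max) == 448.0
    assert float(np.finfo(ml_dtypes.float8_e5m2).max) == 57344.0
    assert float(np.finfo(ml_dtypes.bfloat16).max) == 3.3895313892515355e38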
17 changes: 17 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -116,6 +116,12 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_bfloat16_util;
}

if (enable_fp8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
decl_stream << "#include <cuda_fp8.h>\n";
decl_stream << "#endif\n\n";
}

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
}
@@ -249,6 +255,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
fail = true;
}
if (!fail) return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
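The lane-to-storage-type mapping the codegen emits, summarized for reference (illustrative only; the C++ above is authoritative):

    # lanes -> CUDA storage type printed for fp8 values
    FP8_STORAGE_TYPE = {
        1: "unsigned char",       # __nv_fp8_storage_t
        2: "unsigned short int",  # __nv_fp8x2_storage_t
        4: "unsigned int",        # __nv_fp8x4_storage_t
    }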