diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 93abac52beaa..41c8697f4234 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -44,4 +44,5 @@ pip3 install --upgrade \ junitparser==2.4.2 \ six \ tornado \ - pytest-lazy-fixture + pytest-lazy-fixture \ + ml_dtypes diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index f52e95c756bc..9fb113f56b2c 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -32,6 +32,7 @@ namespace tvm { namespace runtime { + /*! * \brief Runtime primitive data type. * @@ -54,6 +55,8 @@ class DataType { kFloat = kDLFloat, kHandle = TVMArgTypeCode::kTVMOpaqueHandle, kBFloat = kDLBfloat, + kE4M3Float = 6U, + kE5M2Float = 7U, kCustomBegin = 129 }; /*! \brief default constructor */ @@ -76,6 +79,9 @@ class DataType { if (code == kBFloat) { ICHECK_EQ(bits, 16); } + if (code == kE4M3Float || code == kE5M2Float) { + ICHECK_EQ(bits, 8); + } } /*! \return The type code. */ int code() const { return static_cast<int>(data_.code); } @@ -91,6 +97,12 @@ class DataType { bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ bool is_float() const { return code() == DataType::kFloat; } + /*! \return whether type is a float8 type. */ + bool is_float8() const { + return (code() == DataType::kFloat || code() == DataType::kE4M3Float || + code() == DataType::kE5M2Float) && + bits() == 8; + } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ @@ -183,6 +195,18 @@ class DataType { * \return The constructed data type. */ static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); } + /*! + * \brief Construct NV float8 e4m3 datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); } + /*! + * \brief Construct NV float8 e5m2 datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); } /*! * \brief Construct a bool type.
* \param lanes The number of lanes @@ -308,6 +332,10 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { return "handle"; case kDLBfloat: return "bfloat"; + case DataType::kE4M3Float: + return "e4m3_float"; + case DataType::kE5M2Float: + return "e5m2_float"; default: LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); } @@ -376,6 +404,12 @@ inline DLDataType String2DLDataType(std::string s) { } else if (s.substr(0, 6) == "bfloat") { t.code = DataType::kBFloat; scan = s.c_str() + 6; + } else if (s.substr(0, 10) == "e4m3_float") { + t.code = DataType::kE4M3Float; + scan = s.c_str() + 10; + } else if (s.substr(0, 10) == "e5m2_float") { + t.code = DataType::kE5M2Float; + scan = s.c_str() + 10; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 0198feb3cd79..3d5e589ab4a4 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -939,7 +939,8 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span); } } - if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span); + if (t.is_float() || t.is_bfloat16() || t.is_float8()) + return FloatImm(t, static_cast<double>(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 8dee176277d7..1caa71632d3a 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -356,12 +356,26 @@ TVM_DLL Pass NarrowDataType(int target_bits); */ TVM_DLL Pass BF16ComputeLegalize(); +/*! + * \brief Legalize fp8 compute ops: insert a cast to fp16/fp32 + * before each fp8 op, then cast the result back to fp8. + * \param promote_dtype_str The data type used for type promotion, defaults to float16 + * \return The pass. + */ +TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16"); + /*! * \brief Legalize bf16 storage types to u16. * \return The pass. */ TVM_DLL Pass BF16StorageLegalize(); +/*! + * \brief Legalize fp8 storage types to u8. + * \return The pass. + */ +TVM_DLL Pass FP8StorageLegalize(); + /*!
* \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 1a55dccd1130..1cb1ce109af4 100644 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -67,6 +67,7 @@ "attrs", "cloudpickle", "decorator", + "ml_dtypes", "numpy", "psutil", "scipy", diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 999f69bc34c0..adcc3a8e972c 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -19,6 +19,11 @@ import ctypes import json import numpy as np + +try: + import ml_dtypes +except ImportError: + ml_dtypes = None from .base import _LIB, check_call tvm_shape_index_t = ctypes.c_int64 @@ -59,6 +64,8 @@ class DataTypeCode(object): FLOAT = 2 HANDLE = 3 BFLOAT = 4 + E4M3Float = 6 + E5M2Float = 7 class DataType(ctypes.Structure): @@ -71,6 +78,8 @@ class DataType(ctypes.Structure): DataTypeCode.FLOAT: "float", DataTypeCode.HANDLE: "handle", DataTypeCode.BFLOAT: "bfloat", + DataTypeCode.E4M3Float: "e4m3_float", + DataTypeCode.E5M2Float: "e5m2_float", } NUMPY2STR = { np.dtype(np.bool_): "bool", @@ -97,6 +106,8 @@ class DataType(ctypes.Structure): "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1}, "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1}, "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, + "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1}, + "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1}, "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, @@ -141,6 +152,12 @@ def __init__(self, type_str): elif head.startswith("bfloat"): self.type_code = DataTypeCode.BFLOAT head = head[6:] + elif head.startswith("e4m3_float"): + self.type_code = DataTypeCode.E4M3Float + head = head[10:] + elif head.startswith("e5m2_float"): + self.type_code = DataTypeCode.E5M2Float + head = head[10:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel import tvm.runtime._ffi_api @@ -182,6 +199,11 @@ def __ne__(self, other): return not self.__eq__(other) +if ml_dtypes is not None: + DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" + DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" + DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" + RPC_SESS_MASK = 128 diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 5a104be9966d..a0d1116b47e2 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -404,3 +404,20 @@ def have_bf16(compute_version): return True return False + + +def have_fp8(compute_version): + """Whether fp8 support is provided in the specified compute capability or not + + Parameters + ---------- + compute_version : str + GPU capability + """ + major, minor = parse_compute_version(compute_version) + # fp8 is supported in Ada Lovelace (8.9) or later architectures.
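+ # For example, have_fp8("8.9") (Ada) and have_fp8("9.0") (Hopper) return True, while have_fp8("8.6") (Ampere) returns False.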
+ if major == 8 and minor == 9: + return True + if major >= 9: + return True + return False diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index b7a325948895..7669600c49c7 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -19,6 +19,11 @@ import ctypes import warnings import numpy as np + +try: + import ml_dtypes +except ImportError: + ml_dtypes = None import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE @@ -220,6 +225,20 @@ def numpy(self): dtype = "int8" if dtype == "bfloat16": dtype = "uint16" + if dtype == "e4m3_float8": + if ml_dtypes is not None: + dtype = ml_dtypes.float8_e4m3fn + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy." + ) + if dtype == "e5m2_float8": + if ml_dtypes is not None: + dtype = ml_dtypes.float8_e5m2 + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." + ) np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f3aae306bea1..62576db93571 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -40,6 +40,7 @@ def Apply(ftransform): fpass : tvm.transform.Pass The result pass """ + # pylint: disable=unused-argument def _transform(func, mod, ctx): return ftransform(func) @@ -297,6 +298,22 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore +def FP8ComputeLegalize(promote_dtype_str: str = "float32"): + """Legalize fp8 compute Ops. + + Parameters + ---------- + promote_dtype_str : str + The data type we promote fp8 to, options: float16/float32. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore + + def BF16StorageLegalize(): """Legalize bf16 storage types to u16. @@ -308,6 +325,17 @@ def BF16StorageLegalize(): return _ffi_api.BF16StorageLegalize() # type: ignore +def FP8StorageLegalize(): + """Legalize fp8 storage types to u8. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.FP8StorageLegalize() # type: ignore + + def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False): """Replace redundant computations by new variables.
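A minimal usage sketch of the two new passes (illustrative only, not part of the patch; assumes `mod` is an IRModule whose PrimFuncs use the e4m3_float8/e5m2_float8 dtypes):

    import tvm

    # Promote fp8 arithmetic to float16 while keeping fp8 in buffers.
    mod = tvm.tir.transform.FP8ComputeLegalize("float16")(mod)
    # FP8StorageLegalize checks that it runs after MakePackedAPI (empty buffer_map),
    # so it is scheduled by the lowering pipeline (see driver_api.cc below) rather
    # than invoked standalone.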
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 91bc57ccbeb2..3a66313012f2 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -210,6 +210,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::FP8ComputeLegalize()); pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -586,6 +587,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 0e09568f158d..fdd8c2cd8bc5 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -104,7 +104,8 @@ TVM_REGISTER_NODE_TYPE(IntImmNode); FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; - ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.code() >= DataType::kCustomBegin) + ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || + dtype.code() >= DataType::kCustomBegin) << "ValueError: FloatImm supports only float, bfloat16, float8 and custom dtypes, but " << dtype << " was supplied."; // check range for float32, float16, bfloat16 and float8 since they have specified ranges. @@ -119,6 +120,17 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; ICHECK_LE(value, support::kMaxFloat16) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + } else if (dtype.is_bfloat16()) { + ICHECK_GE(value, -support::kMaxBFloat16) + << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; + ICHECK_LE(value, support::kMaxBFloat16) + << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + } else if (dtype.is_float8()) { + double bound = (dtype.code() == DataType::kE4M3Float) ?
support::kMaxE4M3 : support::kMaxE5M2; + ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " exceeds minimum of " + << dtype; + ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " + << dtype; } } ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>(); diff --git a/src/support/scalars.h b/src/support/scalars.h index 2fdbb001d922..2b34914565ed 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -65,6 +65,18 @@ FloatImm ValueToFloatImm(double value, int width); // See https://en.wikipedia.org/wiki/Half-precision_floating-point_format constexpr double kMaxFloat16 = 65504.0; +// 2^127 * (1 + 127/128) +// See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format +constexpr double kMaxBFloat16 = 3.895313892515354759047080037148786688e38; + +// 2^8 * (1 + 6/8) +// See https://arxiv.org/pdf/2209.05433.pdf +constexpr double kMaxE4M3 = 448; + +// 2^15 * (1 + 3/4) +// See https://arxiv.org/pdf/2209.05433.pdf +constexpr double kMaxE5M2 = 57344; + } // namespace support } // namespace tvm diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index d2131c522e38..f17b6b3e1f58 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -116,6 +116,12 @@ std::string CodeGenCUDA::Finish() { decl_stream << _cuda_bfloat16_util; } + if (enable_fp8_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n"; + decl_stream << "#include <cuda_fp8.h>\n"; + decl_stream << "#endif\n\n"; + } + if (enable_warp_shuffle_) { decl_stream << _cuda_warp_intrinsic_util; } @@ -249,6 +255,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; } if (!fail) return; + } else if (t.is_float8()) { + if (t.is_scalar()) { + os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char + } else if (lanes == 2) { + os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short + } else if (lanes == 4) { + os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int + } else { + fail = true; + } + if (!fail) return; } else if (t == DataType::Bool()) { os << "bool"; return; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index bb507c179993..c6cf96d460d4 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -42,7 +42,8 @@ class CodeGenCUDA final : public CodeGenC { void Init(bool output_ssa); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); + return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || need_math_constants_h_ || + need_mma_h_); } // override behavior void PrintFuncPrefix(std::ostream& os) final; @@ -93,6 +94,8 @@ class CodeGenCUDA final : public CodeGenC { bool enable_fp16_{false}; // whether enable bf16 bool enable_bf16_{false}; + // whether enable fp8 + bool enable_fp8_{false}; // whether enable int8 bool enable_int8_{false}; // whether enable warp shuffle intrinsics diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 4439a9c3d711..39214c4546dc 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -143,6 +143,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) !rtype.is_bfloat16()) { // Cast int->bfloat16 when the other operand is a bfloat16 rhs = cast(ltype, rhs); + } else if (!ltype.is_float8() && rtype.is_float8()) { + // Cast lhs to float8 when rhs is a float8 + lhs = cast(rtype, lhs); + } else if
(ltype.is_float8() && !rtype.is_float8()) { + // Cast rhs to float8 when lhs is a float8 + rhs = cast(ltype, rhs); } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (ltype.bits() < rtype.bits()) { diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc new file mode 100644 index 000000000000..de94cf647387 --- /dev/null +++ b/src/tir/transforms/dtype_conversion.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dtype_conversion.cc + * \brief Implementation of data type conversion routines. + */ +#include "dtype_conversion.h" + +namespace tvm { +namespace tir { + +PrimExpr ReinterpretAsUInt(PrimExpr value) { + return reinterpret(GetStorageUIntDType(value.dtype()), value); +} + +DataType GetStorageUIntDType(DataType dtype) { return DataType::UInt(dtype.bits(), dtype.lanes()); } + +PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode) { + DataType src_dtype = src_value.dtype(); + // Step 1: check dtype + // The lanes of src dtype and target dtype must match. + CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) + << "The lanes of the source value's data type must match the target data type."; + auto is_floating_point = [](DataType dtype) { + return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16(); + }; + // Both source dtype and target dtype should be floating point.
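+  // (checked below). The conversion itself works purely on bit patterns: the source is
+  // reinterpreted as a uint, the mantissa is shifted by the width delta (with
+  // round-half-to-even when narrowing), and the exponent is re-biased according to the
+  // two FloatConfigs.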
+ CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); + FloatConfig src_fp = FloatConfig::FromDataType(src_value.dtype()), + tgt_fp = FloatConfig::FromDataType(tgt_dtype); + int exponent_delta = tgt_fp.exponent - src_fp.exponent; + int bias_delta = tgt_fp.bias - src_fp.bias; + int mantissa_delta = tgt_fp.mantissa - src_fp.mantissa; + DataType src_uint = GetStorageUIntDType(src_value.dtype()), + tgt_uint = GetStorageUIntDType(tgt_dtype); + PrimExpr src_uint_value = ReinterpretAsUInt(src_value); + if (mantissa_delta < 0) { + // use rounding + CHECK(round_mode == RoundingMode::kHalfToEven) + << "Currently we only support HalfToEven rounding mode."; + PrimExpr rounding_bias = ((src_uint_value >> (-mantissa_delta)) & 1) + + make_const(src_uint, (int64_t(1) << (-mantissa_delta - 1)) - 1); + src_uint_value = src_uint_value + rounding_bias; + } + if (exponent_delta == 0) { + // number of exponent bits exactly matches + PrimExpr ret = src_uint_value; + if (mantissa_delta >= 0) { + ret = cast(tgt_uint, ret) << mantissa_delta; + } else { // mantissa_delta < 0 + ret = cast(tgt_uint, ret >> (-mantissa_delta)); + } + if (bias_delta > 0) { + ret = ret + (make_const(tgt_uint, bias_delta) << tgt_fp.mantissa); + } else if (bias_delta < 0) { + ret = ret - (make_const(tgt_uint, -bias_delta) << tgt_fp.mantissa); + } + return reinterpret(tgt_dtype, ret); + } else { + // number of exponent bits mismatch. + PrimExpr ret_mantissa = + (mantissa_delta >= 0 ? (cast(tgt_uint, src_uint_value) << mantissa_delta) + : (cast(tgt_uint, src_uint_value >> (-mantissa_delta)))) & + make_const(tgt_uint, (int64_t(1) << (tgt_fp.mantissa)) - 1); + PrimExpr exponent_before_delta = ((src_uint_value << 1) >> (src_fp.mantissa + 1)); + PrimExpr ret_sign = cast(tgt_uint, (src_uint_value >> (src_fp.mantissa + src_fp.exponent))) + << (tgt_fp.mantissa + tgt_fp.exponent); + if (bias_delta >= 0) { + PrimExpr ret_exponent = + (bias_delta > 0) ? (cast(tgt_uint, exponent_before_delta + bias_delta) << tgt_fp.mantissa) + : (cast(tgt_uint, exponent_before_delta) << tgt_fp.mantissa); + return reinterpret(tgt_dtype, ret_mantissa | ret_exponent | ret_sign); + } else { // bias_delta < 0 + PrimExpr round_to_zero = exponent_before_delta < (-bias_delta); + PrimExpr ret_exponent = cast(tgt_uint, exponent_before_delta - (-bias_delta)) + << tgt_fp.mantissa; + return reinterpret(tgt_dtype, if_then_else(round_to_zero, make_const(tgt_uint, 0), + ret_mantissa | ret_exponent | ret_sign)); + } + } +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h new file mode 100644 index 000000000000..b509abb9cd27 --- /dev/null +++ b/src/tir/transforms/dtype_conversion.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dtype_conversion.h + * \brief Header file of data type conversion routines. + */ +#ifndef TVM_TIR_TRANSFORMS_DTYPE_CONVERSION_H_ +#define TVM_TIR_TRANSFORMS_DTYPE_CONVERSION_H_ + +#include <tvm/runtime/data_type.h> +#include <tvm/tir/expr.h> +#include <tvm/tir/op.h> + +namespace tvm { + +namespace tir { + +/*! + * \brief Rounding mode: https://en.wikipedia.org/wiki/Rounding + */ +enum class RoundingMode { + // Round half to nearest even + kHalfToEven = 0U, + // Round down + kDown = 1U, + // Round up + kUp = 2U, + // Round towards zero + kTowardsZero = 3U, +}; + +/*! + * \brief Floating point internal representation. + */ +class FloatConfig { + public: + /*! + * \brief Style of infinite number representation. + */ + enum class InftyStyle { + // Exponent all ones, mantissa all zeros + kIEEE = 0U, + // No representation of infinity + kNone = 1U + }; + /*! + * \brief Style of NaN (not-a-number) representation. + */ + enum class NaNStyle { + // Exponent all ones, mantissa non zeros + // - quiet NaN : 1XXXXX... + // - signaling NaN : 0XXXXX... + kIEEE = 0U, + // No representation of NaN + kNone = 1U, + // Both exponent bits and mantissa bits are all ones. + kAllOnes = 2U, + }; + // The number of exponent bits. + int exponent; + // The number of mantissa (also known as fraction in IEEE format) bits. + int mantissa; + // The exponent bias in IEEE format. + int bias; + // The representation of infinity. + InftyStyle infty_style; + // The representation of NaN (Not a Number). + NaNStyle nan_style; + + FloatConfig(int exponent, int mantissa, int bias, InftyStyle infty_style, NaNStyle nan_style) + : exponent(exponent), + mantissa(mantissa), + bias(bias), + infty_style(infty_style), + nan_style(nan_style) {} + + inline int bits() const { return mantissa + exponent + 1; } + + /*! + * \brief Create float config from data type. + * \param dtype The data type, must be a floating point. + * \return The FloatConfig class containing internal floating point representation. + */ + static FloatConfig FromDataType(DataType dtype) { + CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8()) + << "FloatConfig is only applicable to floating point data types, got " << dtype + << " instead."; + if (dtype.is_float()) { + // IEEE 754 binary formats + // Reference: https://en.wikipedia.org/wiki/Floating-point_arithmetic + switch (dtype.bits()) { + case 16: + return FloatConfig(5, 10, 15, InftyStyle::kIEEE, NaNStyle::kIEEE); + case 32: + return FloatConfig(8, 23, 127, InftyStyle::kIEEE, NaNStyle::kIEEE); + default: + // float64 + return FloatConfig(11, 52, 1023, InftyStyle::kIEEE, NaNStyle::kIEEE); + } + } else if (dtype.is_bfloat16()) { + // bfloat16 + return FloatConfig(8, 7, 127, InftyStyle::kIEEE, NaNStyle::kIEEE); + } else { // float8 + // NVIDIA/Arm/Intel's FP8 formats for Deep Learning + // Reference: https://arxiv.org/abs/2209.05433 + switch (dtype.code()) { + case DataType::kE4M3Float: + // E4M3 format, not consistent with IEEE-754 + return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes); + default: + // E5M2 format, consistent with IEEE-754 + return FloatConfig(5, 2, 15, InftyStyle::kIEEE, NaNStyle::kIEEE); + } + } + } +}; + +/*! + * \brief Reinterpret value as unsigned integer with equal number of bits. + * \param value The value to interpret. + * \return The reinterpreted uint value. + */ +PrimExpr ReinterpretAsUInt(PrimExpr value); + +/*!
+ * \brief Get the unsigned integer data type used as storage when the specified dtype is not + * supported natively. + * \param dtype The data type. + * \return The uint data type, the number of bits is + * the same as input dtype. + */ +DataType GetStorageUIntDType(DataType dtype); + +/*! + * \brief Conversion routine from value stored in one floating point data type to another floating + * point data type. + * \param src_value The floating point value to be converted. + * \param tgt_dtype The target floating point data type. + * \param round_mode The rounding mode to use, defaults to kHalfToEven. + * \return The converted value in target floating point data type. + * \note Used when there is no native data type conversion implementation. + */ +PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, + RoundingMode round_mode = RoundingMode::kHalfToEven); + +} // namespace tir +} // namespace tvm +#endif  // TVM_TIR_TRANSFORMS_DTYPE_CONVERSION_H_ diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc similarity index 69% rename from src/tir/transforms/bf16_legalize.cc rename to src/tir/transforms/unsupported_dtype_legalize.cc index cc57735df6dd..be8876b81550 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -18,10 +18,9 @@ */ /*! - * \file bf16_legalize.cc - * \brief legalize bf16 type by adding cast_to_fp32 + * \file unsupported_dtype_legalize.cc + * \brief legalize bf16/fp8 types by adding casts to fp32/fp16 */ - #include <tvm/runtime/registry.h> #include <tvm/tir/builtin.h> #include <tvm/tir/op.h> @@ -31,21 +30,24 @@ #include <cmath> #include <tuple> +#include "dtype_conversion.h" + namespace tvm { namespace tir { // NOTE: do not touch buffer on function boundary -// remap internal bf16 buffer to f32 if they meet the following condition +// remap internal fp8/bf16 buffers to the promoted dtype if they meet the following conditions // - constant allocation size // - do not have raw pointer access to the buffer // // populate the buffer_remap and var_remap accordingly. -class BF16ComputeLegalizePlanner : public StmtExprVisitor { +class ComputeLegalizePlanner : public StmtExprVisitor { public: - BF16ComputeLegalizePlanner( + ComputeLegalizePlanner( std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap, - std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap) - : buffer_remap_(buffer_remap), var_remap_(var_remap) {} + std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap, + DataType promote_dtype) + : buffer_remap_(buffer_remap), var_remap_(var_remap), promote_dtype_(promote_dtype) {} // run planning to populate buffer remap and var remap.
void Plan(PrimFunc func) { @@ -71,10 +73,12 @@ class BF16ComputeLegalizePlanner : public StmtExprVisitor { } } + virtual bool MatchDType(DataType dtype) const = 0; + void VisitStmt_(const AllocateNode* op) final { - // remap all intermediate constant buffr to fp32 - if (op->dtype.is_bfloat16() && op->ConstantAllocationSize() != 0) { - DataType dtype = DataType::Float(32, op->dtype.lanes()); + // remap all intermediate constant buffers to the promoted data type (fp16/fp32) + if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) { + DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes()); Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); (*var_remap_)[op->buffer_var] = buffer_var; } @@ -109,7 +113,7 @@ auto var_it = var_remap_->find(buf->data); if (var_it == var_remap_->end()) return; - Buffer new_buffer(var_it->second, DataType::Float(32, buf->dtype.lanes()), buf->shape, + Buffer new_buffer(var_it->second, promote_dtype_.with_lanes(buf->dtype.lanes()), buf->shape, buf->strides, buf->elem_offset, buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); (*buffer_remap_)[buf] = new_buffer; @@ -118,42 +122,68 @@ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap_; std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap_; std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> opaque_var_access_; + DataType promote_dtype_; +}; + +class BF16ComputeLegalizePlanner : public ComputeLegalizePlanner { + public: + explicit BF16ComputeLegalizePlanner( + std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap, + std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap, + DataType promote_dtype) + : ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {} + bool MatchDType(DataType dtype) const { return dtype.is_bfloat16(); } +}; + +class FP8ComputeLegalizePlanner : public ComputeLegalizePlanner { + public: + explicit FP8ComputeLegalizePlanner( + std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap, + std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap, + DataType promote_dtype) + : ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {} + bool MatchDType(DataType dtype) const { return dtype.is_float8(); } }; -#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC) \ - PrimExpr VisitExpr_(const OP* op) final { \ - PrimExpr origin_a = PromoteBF16ToF32(this->VisitExpr(op->a)); \ - PrimExpr origin_b = PromoteBF16ToF32(this->VisitExpr(op->b)); \ - \ - if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) { \ - return GetRef<PrimExpr>(op); \ - } else { \ - return FUNC(origin_a, origin_b); \ - } \ } +#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC) \ + PrimExpr VisitExpr_(const OP* op) final { \ + PrimExpr origin_a = PromoteToTarget(this->VisitExpr(op->a)); \ + PrimExpr origin_b = PromoteToTarget(this->VisitExpr(op->b)); \ + \ + if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) { \ + return GetRef<PrimExpr>(op); \ + } else { \ + return FUNC(origin_a, origin_b); \ + } \ } -// NOTE: Legalize the BF16 computations +// NOTE: Legalize the FP8/BF16 computations // to floating point computations and only keeps the -// bf16 storage which can further be legalized by BF16StorageLegalizer -// BF16StorageLegalizer will be called at a much later time +// fp8/bf16 storage which can further be legalized by FP8/BF16StorageLegalizer +// FP8/BF16StorageLegalizer will be called at a much later time // point in the TIR lowering phases.
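+// For example, with fp16 promotion, C[i] = A[i] + B[i] on e4m3_float8 buffers conceptually
+// becomes C[i] = fp8(fp16(A[i]) + fp16(B[i])), where the fp8<->fp16 conversions are emitted
+// as bit-level DTypeConversion expressions rather than hardware casts.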
-class BF16ComputeLegalizer : public StmtExprMutator { +class ComputeLegalizer : public StmtExprMutator { public: - PrimFunc Legalize(PrimFunc func) { - BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_); - planner.Plan(func); + explicit ComputeLegalizer(DataType promote_dtype) : promote_dtype_(promote_dtype) {} + + PrimFunc LegalizeWithPlanner(PrimFunc func, ComputeLegalizePlanner* planner) { + planner->Plan(func); auto* n = func.CopyOnWrite(); n->body = this->VisitStmt(std::move(n->body)); return func; } + virtual PrimFunc Legalize(PrimFunc func) = 0; + + virtual bool MatchDType(DataType dtype) const = 0; + protected: PrimExpr VisitExpr_(const CastNode* op) final { - auto op_val = PromoteBF16ToF32(this->VisitExpr(op->value)); + auto op_val = PromoteToTarget(this->VisitExpr(op->value)); - // all casts to BF16 becomes f32 - if (op->dtype.is_bfloat16()) { - return cast(DataType::Float(32, op->dtype.lanes()), op_val); + // all casts to the matched data type (fp8/bf16) become the promoted dtype + if (MatchDType(op->dtype)) { + return cast(promote_dtype_.with_lanes(op->dtype.lanes()), op_val); } if (op_val.same_as(op->value)) { @@ -165,8 +195,8 @@ class BF16ComputeLegalizer : public StmtExprMutator { PrimExpr VisitExpr_(const SelectNode* op) final { PrimExpr condition = this->VisitExpr(op->condition); - PrimExpr true_value = PromoteBF16ToF32(this->VisitExpr(op->true_value)); - PrimExpr false_value = PromoteBF16ToF32(this->VisitExpr(op->false_value)); + PrimExpr true_value = PromoteToTarget(this->VisitExpr(op->true_value)); + PrimExpr false_value = PromoteToTarget(this->VisitExpr(op->false_value)); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef<PrimExpr>(op); @@ -176,7 +206,7 @@ } PrimExpr VisitExpr_(const BroadcastNode* op) final { - PrimExpr value = PromoteBF16ToF32(this->VisitExpr(op->value)); + PrimExpr value = PromoteToTarget(this->VisitExpr(op->value)); if (value.same_as(op->value)) { return GetRef<PrimExpr>(op); } else { @@ -185,7 +215,7 @@ } PrimExpr VisitExpr_(const ShuffleNode* op) final { - auto fexpr = [this](const PrimExpr& e) { return PromoteBF16ToF32(this->VisitExpr(e)); }; + auto fexpr = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); }; auto vectors = op->vectors.Map(fexpr); if (vectors.same_as(op->vectors)) { return GetRef<PrimExpr>(op); } else { @@ -200,10 +230,10 @@ return StmtExprMutator::VisitExpr_(op); } - // update normal computations to return f32 instead. + // update normal computations to return the promoted dtype instead.
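+    // e.g. a call such as exp(x) whose result dtype is fp8/bf16 is rewritten to return the
+    // promoted dtype; the cast back happens at the enclosing BufferStore.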
- auto fmutate = [this](const PrimExpr& e) { return PromoteBF16ToF32(this->VisitExpr(e)); }; + auto fmutate = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); }; Array<PrimExpr> args = op->args.Map(fmutate); - if (op->dtype.is_bfloat16()) { - return Call(DataType::Float(32, op->dtype.lanes()), op->op, args); + if (MatchDType(op->dtype)) { + return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args); } if (args.same_as(op->args)) { return GetRef<PrimExpr>(op); @@ -213,8 +243,8 @@ } PrimExpr VisitExpr_(const FloatImmNode* op) final { - if (op->dtype.is_bfloat16()) { - return FloatImm(DataType::Float(32), op->value); + if (MatchDType(op->dtype)) { + return FloatImm(promote_dtype_, op->value); } return GetRef<PrimExpr>(op); } @@ -231,7 +261,7 @@ } PrimExpr VisitExpr_(const LetNode* op) final { - PrimExpr value = PromoteBF16ToF32(op->value); + PrimExpr value = PromoteToTarget(op->value); Var var = op->var; if (value.dtype() != op->value.dtype()) { var = op->var.copy_with_dtype(op->value.dtype()); @@ -261,7 +291,7 @@ DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=); Stmt VisitStmt_(const LetStmtNode* op) final { - PrimExpr value = PromoteBF16ToF32(op->value); + PrimExpr value = PromoteToTarget(op->value); Var var = op->var; if (value.dtype() != op->value.dtype()) { var = op->var.copy_with_dtype(op->value.dtype()); @@ -287,13 +317,16 @@ if (value.same_as(op->value) && indices.same_as(op->indices) && new_buf.same_as(op->buffer)) { return GetRef<Stmt>(op); } else { - if (new_buf->dtype.is_bfloat16()) { - value = CastF32ToBF16(value); + if (MatchDType(new_buf->dtype)) { + int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; + int buffer_lanes = new_buf->dtype.lanes(); + DataType legalized_dtype = new_buf->dtype.with_lanes(index_lanes * buffer_lanes); + value = CastTargetToDType(value, legalized_dtype); } if (value.dtype() != new_buf->dtype) { - // this happens when buffer get rewritten to f32 - // but values remain as bf16 - ICHECK(value.dtype().is_bfloat16()); + // this happens when the buffer gets rewritten to the promoted dtype + // but values remain as fp8/bf16 + ICHECK(MatchDType(value->dtype)); value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); } return BufferStore(new_buf, value, indices); @@ -373,41 +406,29 @@ private: /*! - * \brief promote BF16 to F32 and keep other values unchanged. + * \brief promote value to target datatype F16/F32 and keep other values unchanged. * \param value The input value. * \return The converted value. */ - PrimExpr PromoteBF16ToF32(PrimExpr value) { - if (!value.dtype().is_bfloat16()) return value; + PrimExpr PromoteToTarget(PrimExpr value) { + if (!MatchDType(value.dtype())) return value; if (const CastNode* cast = value.as<CastNode>()) { - if (cast->value.dtype() == DataType::Float(32)) return cast->value; + if (cast->value.dtype() == promote_dtype_.with_lanes(value.dtype().lanes())) + return cast->value; } - DataType f32 = DataType::Float(32, value.dtype().lanes()); - DataType u16 = DataType::UInt(16, value.dtype().lanes()); - DataType u32 = DataType::UInt(32, value.dtype().lanes()); - // reinterpret((cast(reinterpret(bf16_value)) << 16)) - return reinterpret(f32, cast(u32, reinterpret(u16, value)) << 16); + return DTypeConversion(value, promote_dtype_.with_lanes(value.dtype().lanes())); } /*!
- * \brief Cast value to F32 to BF16 and keep other values unchanged. + * \brief Cast value from promoted datatype (FP16/FP32) back to BF16/FP8 and keep other values + * unchanged. * \param value The input value + * \param dtype The fp8/bf16 data type to cast back to. * \return The converted value. */ - PrimExpr CastF32ToBF16(PrimExpr value) { + PrimExpr CastTargetToDType(PrimExpr value, DataType dtype) { if (!value.dtype().is_float()) return value; - ICHECK_EQ(value.dtype().bits(), 32); - DataType bf16 = DataType::BFloat(16, value.dtype().lanes()); - DataType u16 = DataType::UInt(16, value.dtype().lanes()); - DataType u32 = DataType::UInt(32, value.dtype().lanes()); - PrimExpr u32_val = reinterpret(u32, value); - - if (round_to_even_) { - PrimExpr rounding_bias = ((u32_val >> 16) & 1) + make_const(u32, 0x7FFF); - u32_val = u32_val + rounding_bias; - } - // reinterpret((cast(reinterpret(f32_value)) >> 16)) - return reinterpret(bf16, cast(u16, u32_val >> 16)); + ICHECK_EQ(value.dtype(), this->promote_dtype_.with_lanes(value.dtype().lanes())); + return DTypeConversion(value, dtype); } Buffer GetRemappedBuffer(Buffer buf) { @@ -418,19 +439,40 @@ return buf; } - bool round_to_even_{true}; - + protected: + DataType promote_dtype_; std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_; std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_; }; +class BF16ComputeLegalizer : public ComputeLegalizer { + public: + BF16ComputeLegalizer() : ComputeLegalizer(DataType::Float(32)) {} + PrimFunc Legalize(PrimFunc func) { + BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_, promote_dtype_); + return LegalizeWithPlanner(func, &planner); + } + bool MatchDType(DataType dtype) const { return dtype.is_bfloat16(); } +}; + +class FP8ComputeLegalizer : public ComputeLegalizer { + public: + explicit FP8ComputeLegalizer(DataType promote_dtype) : ComputeLegalizer(promote_dtype) {} + PrimFunc Legalize(PrimFunc func) { + FP8ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_, promote_dtype_); + return LegalizeWithPlanner(func, &planner); + } + bool MatchDType(DataType dtype) const { return dtype.is_float8(); } +}; + /*! - * \brief This Pass legalizes remaining BF16 storages to u16 + * \brief This Pass legalizes remaining FP8/BF16 storages to unsigned integers with equal number of + * bits. * - * This pass needs to happens after BF16ComputeLegalizer and serves - * as a way to support BF16 on platforms that do not have native support. + * This pass needs to happen after FP8/BF16ComputeLegalizer and serves + * as a way to support FP8/BF16 on platforms that do not have native support.
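+ * For example, an e4m3_float8/e5m2_float8 allocation becomes a uint8 allocation and the
+ * remaining fp8 loads/stores are reinterpreted via GetStorageUIntDType.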
*/ -class BF16StorageLegalizer : public StmtExprMutator { +class StorageLegalizer : public StmtExprMutator { public: PrimFunc Legalize(PrimFunc func) { ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI"; @@ -452,8 +494,8 @@ } Stmt VisitStmt_(const AllocateNode* op) final { - if (op->dtype.is_bfloat16()) { - DataType dtype = DataType::UInt(16, op->dtype.lanes()); + if (MatchDType(op->dtype)) { + DataType dtype = GetStorageUIntDType(op->dtype); Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); var_remap_[op->buffer_var] = buffer_var; return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); @@ -467,8 +509,8 @@ // in a rare case the buffer didn't get remapped - // because the original var is not bfloat* + // because the original var is not a fp8/bf16 pointer // force remap here - if (buf->dtype.is_bfloat16()) { - buf = Buffer(buf->data, DataType::UInt(16, buf->dtype.lanes()), buf->shape, buf->strides, + if (MatchDType(buf->dtype)) { + buf = Buffer(buf->data, GetStorageUIntDType(buf->dtype), buf->shape, buf->strides, buf->elem_offset, buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); buffer_remap_[op->buffer] = buf; @@ -506,13 +548,13 @@ } Stmt VisitStmt_(const BufferStoreNode* op) final { - PrimExpr value = this->ChangeBF16ToU16(VisitExpr(op->value)); + PrimExpr value = this->ChangeToUInt(VisitExpr(op->value)); Buffer new_buf = GetRemappedBuffer(op->buffer); auto indices = op->indices.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) && value.same_as(op->value)) { return GetRef<Stmt>(op); } else { - if (op->value.dtype().is_bfloat16()) { + if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); } return BufferStore(new_buf, value, indices); @@ -558,8 +600,8 @@ PrimExpr value = VisitExpr(op->args[0]); // sometimes the input dtype can change and we can skip. if (value.dtype() == op->dtype) return value; - if (op->dtype.is_bfloat16()) { - return reinterpret(DataType::UInt(16, op->dtype.lanes()), value); + if (MatchDType(op->dtype)) { + return reinterpret(GetStorageUIntDType(op->dtype), value); } if (op->args[0].same_as(value)) { return GetRef<PrimExpr>(op); } @@? return StmtExprMutator::VisitExpr_(op); } + virtual bool MatchDType(DataType dtype) const = 0; + private: /*! - * \brief Change BF16 value to U16 value. + * \brief Change float value to uint value. * \param value The input value. * \return The converted value.
*/ - PrimExpr ChangeBF16ToU16(PrimExpr value) { - if (!value.dtype().is_bfloat16()) return value; + PrimExpr ChangeToUInt(PrimExpr value) { + if (!MatchDType(value->dtype)) return value; auto* call = value.as<CallNode>(); if (call && call->op.same_as(builtin::reinterpret())) { - return reinterpret(DataType::UInt(16, value.dtype().lanes()), call->args[0]); + return reinterpret(GetStorageUIntDType(value->dtype), call->args[0]); } else { return value; } @@ -591,9 +635,9 @@ if (var.dtype().is_handle()) { if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) { if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) { - if (elem_type->dtype.is_bfloat16()) { - Var new_var = Var(var->name_hint, - PointerType(PrimType(DataType::UInt(16, elem_type->dtype.lanes())))); + if (MatchDType(elem_type->dtype)) { + Var new_var = + Var(var->name_hint, PointerType(PrimType(GetStorageUIntDType(elem_type->dtype)))); var_remap_[var] = new_var; return new_var; } @@ -611,13 +655,12 @@ Buffer new_buf = buf; auto var_it = var_remap_.find(buf->data); if (var_it != var_remap_.end()) { - DataType dtype = - buf->dtype.is_bfloat16() ? DataType::UInt(16, buf->dtype.lanes()) : buf->dtype; + DataType dtype = MatchDType(buf->dtype) ? GetStorageUIntDType(buf->dtype) : buf->dtype; new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); } else { - ICHECK(!buf->dtype.is_bfloat16()) << "Cannot find var remap for " << buf; + ICHECK(!MatchDType(buf->dtype)) << "Cannot find var remap for " << buf; } buffer_remap_[buf] = new_buf; @@ -629,6 +672,16 @@ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_; }; +class BF16StorageLegalizer : public StorageLegalizer { + public: + bool MatchDType(DataType dtype) const { return dtype.is_bfloat16(); } +}; + +class FP8StorageLegalizer : public StorageLegalizer { + public: + bool MatchDType(DataType dtype) const { return dtype.is_float8(); } +}; + namespace transform { Pass BF16ComputeLegalize() { @@ -651,6 +704,27 @@ Pass BF16StorageLegalize() { TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); +Pass FP8ComputeLegalize(String promote_dtype_str) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + // TODO(tvm-team): skip if the target supports fp8 + return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); + +Pass FP8StorageLegalize() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + // TODO(tvm-team): skip if the target supports fp8 + return FP8StorageLegalizer().Legalize(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_datatype_nv_fp8.py b/tests/python/unittest/test_datatype_nv_fp8.py new file mode 100644 index 000000000000..8313a97ee138 --- /dev/null +++ b/tests/python/unittest/test_datatype_nv_fp8.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +import tvm.tir as tir +from tvm import te +from tvm.script import tir as T + +try: + from ml_dtypes import float8_e4m3fn as e4m3_float8, float8_e5m2 as e5m2_float8 +except ImportError: + e4m3_float8, e5m2_float8 = None, None + + +def fp8_unary(dtype: str): + @T.prim_func + def func( + a: T.handle, + b: T.handle, + a_add_b: T.handle, + a_sub_b: T.handle, + a_mul_b: T.handle, + a_fp32: T.handle, + a_roundtrip: T.handle, + ) -> None: + A = T.match_buffer(a, [128], dtype=dtype) + B = T.match_buffer(b, [128], dtype=dtype) + A_add_B = T.match_buffer(a_add_b, [128], dtype=dtype) + A_sub_B = T.match_buffer(a_sub_b, [128], dtype=dtype) + A_mul_B = T.match_buffer(a_mul_b, [128], dtype=dtype) + A_fp32 = T.match_buffer(a_fp32, [128], dtype="float32") + A_roundtrip = T.match_buffer(a_roundtrip, [128], dtype=dtype) + for i in range(128): + with T.block("fp8_unary"): + vi = T.axis.spatial(128, i) + A_add_B[vi] = A[vi] + B[vi] + A_sub_B[vi] = A[vi] - B[vi] + A_mul_B[vi] = A[vi] * B[vi] + A_fp32[vi] = A[vi] + A_roundtrip[vi] = A_fp32[vi] + + return func + + +np_dtype, dtype_str = tvm.testing.parameters( + (e4m3_float8, "e4m3_float8"), (e5m2_float8, "e5m2_float8") +) + + +def test_create_nv_fp8_nd_array(np_dtype, dtype_str): + if np_dtype is None: + # Skip test if ml_dtypes is not installed + return + x = np.random.rand(128, 128).astype(np_dtype) + x_nd = tvm.nd.array(x) + assert x_nd.dtype == dtype_str + + +def test_fp8_unary_op(np_dtype, dtype_str): + func = fp8_unary(dtype_str) + if not tvm.testing.device_enabled("llvm"): + return + if np_dtype is None: + # Skip test if ml_dtypes is not installed + return + + f = tvm.build(func, target="llvm") + a = np.random.randn(128).astype(np_dtype) + b = np.random.randn(128).astype(np_dtype) + a_add_b = np.zeros(128).astype(np_dtype) + a_sub_b = np.zeros(128).astype(np_dtype) + a_mul_b = np.zeros(128).astype(np_dtype) + a_fp32 = np.zeros(128).astype(np.float32) + a_roundtrip = np.zeros(128).astype(np_dtype) + args = [tvm.nd.array(arr) for arr in [a, b, a_add_b, a_sub_b, a_mul_b, a_fp32, a_roundtrip]] + f(*args) + + +def test_nv_fp8_buffer(np_dtype, dtype_str): + m = te.size_var("m") + n = te.size_var("n") + A = tvm.tir.decl_buffer((m, n), dtype_str) + assert A.dtype == dtype_str + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index ababfd489af5..20de9dc59434 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -53,12 +53,11 @@ def f32tou16(v): rounding_bias = (uint32_v >> tvm.tir.const(16, "uint32")) & tvm.tir.const(1, "uint32") rounding_bias += tvm.tir.const(0x7FFF, "uint32") uint32_v =
uint32_v + rounding_bias - return uint32_v >> tvm.tir.const(16, "uint32") + return (uint32_v >> tvm.tir.const(16, "uint32")).astype("uint16") def f32tobf16(v): - uint32_v = f32tou16(v) - return T.reinterpret("bfloat16", uint32_v.astype("uint16")) + return T.reinterpret("bfloat16", f32tou16(v)) def get_after_compute_legalize(): diff --git a/tests/python/unittest/test_tir_transform_fp8_legalize.py b/tests/python/unittest/test_tir_transform_fp8_legalize.py new file mode 100644 index 000000000000..f5786808a6f3 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_fp8_legalize.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.script +import tvm.testing +from tvm.script import tir as T + +# pylint: disable=no-member,invalid-name,unused-variable + + +def get_before(dtype: str): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(Aptr: T.handle(dtype), Bptr: T.handle(dtype), Dptr: T.handle(dtype)): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), dtype, data=Aptr) + B = T.decl_buffer((100,), dtype, data=Bptr) + D = T.decl_buffer((100,), dtype, data=Dptr) + C = T.decl_buffer((100,), dtype) + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return Before + + +def promote_f8(f8_dtype: str, promote_dtype: str, v): + return promote_uint8(f8_dtype, promote_dtype, T.reinterpret("uint8", v)) + + +def cast_to_f8(f8_dtype: str, promote_dtype: str, v): + return T.reinterpret(f8_dtype, cast_to_uint8(f8_dtype, promote_dtype, v)) + + +def get_after_compute_legalize(dtype: str, promote_dtype: str): + @tvm.script.ir_module + class After: + @T.prim_func + def main(Aptr: T.handle(dtype), Bptr: T.handle(dtype), Dptr: T.handle(dtype)): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), dtype, data=Aptr) + B = T.decl_buffer((100,), dtype, data=Bptr) + D = T.decl_buffer((100,), dtype, data=Dptr) + C = T.decl_buffer((100,), promote_dtype) + for i in T.grid(100): + C[i] = promote_f8(dtype, promote_dtype, A[i]) + promote_f8( + dtype, promote_dtype, B[i] + ) + D[i] = cast_to_f8(dtype, promote_dtype, T.exp(C[i])) + + return After + + +def promote_uint8(f8_dtype: str, promote_dtype: str, v): + if f8_dtype == "e4m3_float8": + if promote_dtype == "float16": + mantissa = T.bitwise_and( + T.shift_left(T.Cast("uint16", v), T.uint16(7)), T.uint16(0x3FF) + ) + exponent = T.shift_left( + T.Cast( + "uint16", + T.shift_right(T.shift_left(v, T.uint8(1)), T.uint8(4)) + T.uint8(8), + ), + T.uint16(10), + ) + sign = T.shift_left(T.Cast("uint16", T.shift_right(v, T.uint8(7))), T.uint16(15)) + return T.reinterpret("float16", T.bitwise_or(T.bitwise_or(mantissa, exponent), sign)) + else: # promote_dtype == "float32" + mantissa = T.bitwise_and( + T.shift_left(T.Cast("uint32", v), 
T.uint32(20)), T.uint32(0x7FFFFF) + ) + exponent = T.shift_left( + T.Cast( + "uint32", + T.shift_right(T.shift_left(v, T.uint8(1)), T.uint8(4)) + T.uint8(120), + ), + T.uint32(23), + ) + sign = T.shift_left(T.Cast("uint32", T.shift_right(v, T.uint8(7))), T.uint32(31)) + return T.reinterpret("float32", T.bitwise_or(T.bitwise_or(mantissa, exponent), sign)) + else: # f8_dtype == "e5m2_float8" + if promote_dtype == "float16": + return T.reinterpret("float16", T.shift_left(T.Cast("uint16", v), T.uint16(8))) + else: # promote_dtype == "float32" + mantissa = T.bitwise_and( + T.shift_left(T.Cast("uint32", v), T.uint32(21)), T.uint32(0x7FFFFF) + ) + exponent = T.shift_left( + T.Cast( + "uint32", + T.shift_right(T.shift_left(v, T.uint8(1)), T.uint8(3)) + T.uint8(112), + ), + T.uint32(23), + ) + sign = T.shift_left(T.Cast("uint32", T.shift_right(v, T.uint8(7))), T.uint32(31)) + return T.reinterpret("float32", T.bitwise_or(T.bitwise_or(mantissa, exponent), sign)) + + +def cast_to_uint8(f8_dtype: str, promote_dtype: str, v): + if f8_dtype == "e4m3_float8": + if promote_dtype == "float16": + uint16_v = T.reinterpret("uint16", v) + rounding_bias = T.bitwise_and( + T.shift_right(uint16_v, T.uint16(7)), + T.uint16(1), + ) + T.uint16(0x3F) + uint16_v = uint16_v + rounding_bias + mantissa = T.bitwise_and( + T.Cast("uint8", T.shift_right(uint16_v, T.uint8(7))), T.uint8(0x7) + ) + exponent_before_delta = T.shift_right(T.shift_left(uint16_v, T.uint16(1)), T.uint16(11)) + round_to_zero = exponent_before_delta < T.uint16(8) + exponent = T.shift_left( + T.Cast("uint8", exponent_before_delta - T.uint16(8)), + T.uint8(3), + ) + sign = T.shift_left(T.Cast("uint8", T.shift_right(uint16_v, T.uint16(15))), T.uint8(7)) + return T.if_then_else( + round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, exponent), sign) + ) + else: # promote_dtype == "float32" + uint32_v = T.reinterpret("uint32", v) + rounding_bias = T.bitwise_and( + T.shift_right(uint32_v, T.uint32(20)), T.uint32(1) + ) + T.uint32(0x7FFFF) + uint32_v = uint32_v + rounding_bias + mantissa = T.bitwise_and( + T.Cast("uint8", T.shift_right(uint32_v, T.uint8(20))), T.uint8(0x7) + ) + exponent_before_delta = T.shift_right(T.shift_left(uint32_v, T.uint32(1)), T.uint32(24)) + round_to_zero = exponent_before_delta < T.uint32(120) + exponent = T.shift_left( + T.Cast("uint8", exponent_before_delta - T.uint32(120)), T.uint8(3) + ) + sign = T.shift_left(T.Cast("uint8", T.shift_right(uint32_v, T.uint32(31))), T.uint8(7)) + return T.if_then_else( + round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, exponent), sign) + ) + else: # f8_dtype == "e5m2_float8" + if promote_dtype == "float16": + uint16_v = T.reinterpret("uint16", v) + rounding_bias = T.bitwise_and( + T.shift_right(uint16_v, T.uint16(8)), T.uint16(1) + ) + T.uint16(0x7F) + uint16_v = uint16_v + rounding_bias + return T.Cast("uint8", T.shift_right(uint16_v, T.uint16(8))) + else: # promote_dtype == "float32" + uint32_v = T.reinterpret("uint32", v) + rounding_bias = T.bitwise_and( + T.shift_right(uint32_v, T.uint32(21)), T.uint32(1) + ) + T.uint32(0xFFFFF) + uint32_v = uint32_v + rounding_bias + mantissa = T.bitwise_and( + T.Cast("uint8", T.shift_right(uint32_v, T.uint8(21))), T.uint8(0x3) + ) + exponent_before_delta = T.shift_right(T.shift_left(uint32_v, T.uint32(1)), T.uint32(24)) + round_to_zero = exponent_before_delta < T.uint32(112) + exponent = T.shift_left( + T.Cast("uint8", exponent_before_delta - T.uint32(112)), T.uint8(2) + ) + sign = T.shift_left(T.Cast("uint8", T.shift_right(uint32_v, 
T.uint32(31))), T.uint8(7)) + return T.if_then_else( + round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, exponent), sign) + ) + + +def get_after_storage_legalize(dtype: str, promote_dtype: str): + @tvm.script.ir_module + class After: + @T.prim_func + def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: T.handle("uint8")): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "uint8", data=Aptr) + B = T.decl_buffer((100,), "uint8", data=Bptr) + D = T.decl_buffer((100,), "uint8", data=Dptr) + C = T.decl_buffer((100,), promote_dtype) + for i in T.grid(100): + C[i] = promote_uint8(dtype, promote_dtype, A[i]) + promote_uint8( + dtype, promote_dtype, B[i] + ) + D[i] = cast_to_uint8(dtype, promote_dtype, T.exp(C[i])) + + return After + + +dtype = tvm.testing.parameter("e4m3_float8", "e5m2_float8") +promote_dtype = tvm.testing.parameter("float16", "float32") + + +def test_fp8_compute_legalize(dtype, promote_dtype): + before = get_before(dtype) + expected = get_after_compute_legalize(dtype, promote_dtype) + # run the transform twice to ensure the pass is idempotent + # under repeated application + after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before) + after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after) + tvm.ir.assert_structural_equal(after, expected) + + +def test_fp8_storage_legalize(dtype, promote_dtype): + before = get_after_compute_legalize(dtype, promote_dtype) + after = tvm.tir.transform.FP8StorageLegalize()(before) + expected = get_after_storage_legalize(dtype, promote_dtype) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main()
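Taken together, the runtime pieces above give an fp8 round trip through numpy via ml_dtypes. A minimal sketch (assumes ml_dtypes is installed; mirrors test_create_nv_fp8_nd_array):

    import numpy as np
    import tvm
    from ml_dtypes import float8_e4m3fn

    x = np.random.rand(16).astype(float8_e4m3fn)
    x_nd = tvm.nd.array(x)  # NUMPY2STR maps this dtype to "e4m3_float8"
    assert x_nd.dtype == "e4m3_float8"
    y = x_nd.numpy()        # converted back through ml_dtypes.float8_e4m3fn
    np.testing.assert_array_equal(x, y)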