Commit b13be93

[DataType] Initial support of fp8 (e4m3/e5m2) (#14863)
NVIDIA recently announced official support for two fp8 data types: e4m3 and e5m2. The first has 4 exponent bits and 3 mantissa bits, while the second has 5 exponent bits and 2 mantissa bits; NVIDIA recommends e4m3 for the forward pass and e5m2 (larger dynamic range) for the backward pass. TVM currently has no support for these data types. As a first step toward fp8 support, this PR adds new type codes for `e4m3_float8` and `e5m2_float8` and implements the legalization passes `FP8ComputeLegalize` and `FP8StorageLegalize`, so the types can be used on backends without native fp8 support.
1 parent 8543cec commit b13be93
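As an illustration (not part of this commit), here is a minimal Python sketch of how the new dtypes might be exercised from TE once this lands; whether a given target compiles it depends on the legalization passes and the target's codegen support:

import tvm
from tvm import te

# Hypothetical usage sketch: round-trip fp32 data through e4m3_float8 and rely
# on the legalization passes when the target has no native fp8 support.
A = te.placeholder((128,), dtype="float32", name="A")
B = te.compute((128,), lambda i: A[i].astype("e4m3_float8"), name="B")
C = te.compute((128,), lambda i: B[i].astype("float32") * 2.0, name="C")

func = te.create_prim_func([A, C])
# FP8ComputeLegalize / FP8StorageLegalize are wired into the default lowering
# pipeline (see driver_api.cc below), so an ordinary build picks them up.
mod = tvm.build(func, target="llvm")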

21 files changed (+959 / -102 lines)

docker/install/ubuntu_install_python_package.sh

Lines changed: 2 additions & 1 deletion
@@ -44,4 +44,5 @@ pip3 install --upgrade \
     junitparser==2.4.2 \
     six \
     tornado \
-    pytest-lazy-fixture
+    pytest-lazy-fixture \
+    ml_dtypes

include/tvm/runtime/data_type.h

Lines changed: 34 additions & 0 deletions
@@ -32,6 +32,7 @@
 
 namespace tvm {
 namespace runtime {
+
 /*!
  * \brief Runtime primitive data type.
  *
@@ -54,6 +55,8 @@ class DataType {
     kFloat = kDLFloat,
     kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
     kBFloat = kDLBfloat,
+    kE4M3Float = 6U,
+    kE5M2Float = 7U,
     kCustomBegin = 129
   };
   /*! \brief default constructor */
@@ -76,6 +79,9 @@ class DataType {
     if (code == kBFloat) {
       ICHECK_EQ(bits, 16);
     }
+    if (code == kE4M3Float || code == kE5M2Float) {
+      ICHECK_EQ(bits, 8);
+    }
   }
   /*! \return The type code. */
   int code() const { return static_cast<int>(data_.code); }
@@ -91,6 +97,12 @@ class DataType {
   bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
+  /*! \return whether type is a float8 type. */
+  bool is_float8() const {
+    return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
+            code() == DataType::kE5M2Float) &&
+           bits() == 8;
+  }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
@@ -183,6 +195,18 @@ class DataType {
   * \return The constructed data type.
   */
  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
+  /*!
+   * \brief Construct NV float8 e4m3 datatype.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
+  /*!
+   * \brief Construct NV float8 e5m2 datatype.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
  /*!
   * \brief Construct a bool type.
   * \param lanes The number of lanes
@@ -308,6 +332,10 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
      return "handle";
    case kDLBfloat:
      return "bfloat";
+    case DataType::kE4M3Float:
+      return "e4m3_float";
+    case DataType::kE5M2Float:
+      return "e5m2_float";
    default:
      LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
  }
@@ -376,6 +404,12 @@ inline DLDataType String2DLDataType(std::string s) {
  } else if (s.substr(0, 6) == "bfloat") {
    t.code = DataType::kBFloat;
    scan = s.c_str() + 6;
+  } else if (s.substr(0, 10) == "e4m3_float") {
+    t.code = DataType::kE4M3Float;
+    scan = s.c_str() + 10;
+  } else if (s.substr(0, 10) == "e5m2_float") {
+    t.code = DataType::kE5M2Float;
+    scan = s.c_str() + 10;
  } else if (s.substr(0, 6) == "custom") {
    t.code = ParseCustomDatatype(s, &scan);
  } else {
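For reference, a small sketch (not taken from this commit) of how these type codes and string forms surface on the Python side; the scalar spellings are "e4m3_float8" / "e5m2_float8", with the usual "x<lanes>" suffix for vector types:

from tvm._ffi.runtime_ctypes import DataType, DataTypeCode

dt = DataType("e4m3_float8")      # parsed by the prefix checks added in runtime_ctypes.py below
assert dt.type_code == DataTypeCode.E4M3Float
assert dt.bits == 8 and dt.lanes == 1
assert str(dt) == "e4m3_float8"

vec = DataType("e5m2_float8x4")   # vector form: four e5m2 lanes
assert vec.lanes == 4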

include/tvm/tir/op.h

Lines changed: 2 additions & 1 deletion
@@ -939,7 +939,8 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
       return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
     }
   }
-  if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span);
+  if (t.is_float() || t.is_bfloat16() || t.is_float8())
+    return FloatImm(t, static_cast<double>(value), span);
   // For now, we store const scalar values of custom datatypes within doubles; later, during the
   // datatypes lowering pass, we will lower the value to its true representation in the format
   // specified by the datatype.
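A tiny sketch (not from this commit's tests) of the effect: scalar constants of the fp8 dtypes are now represented as FloatImm, the same as float and bfloat16 constants. This assumes the FloatImm constructor itself accepts the new dtypes, which the MakeConstScalar change above relies on:

import tvm

# Hypothetical example; assumes FloatImm accepts the fp8 dtypes after this PR.
c = tvm.tir.FloatImm("e4m3_float8", 1.5)
assert c.dtype == "e4m3_float8"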

include/tvm/tir/transform.h

Lines changed: 14 additions & 0 deletions
@@ -394,12 +394,26 @@ TVM_DLL Pass NarrowDataType(int target_bits);
 */
TVM_DLL Pass BF16ComputeLegalize();

+/*!
+ * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
+ *        before Ops, then add a cast back to fp8.
+ * \param promote_dtype_str The data type used for type promotion, defaults to float16
+ * \return The pass.
+ */
+TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
+
/*!
 * \brief Legalize bf16 storage types to u16.
 * \return The pass.
 */
TVM_DLL Pass BF16StorageLegalize();

+/*!
+ * \brief Legalize fp8 storage types to u8.
+ * \return The pass.
+ */
+TVM_DLL Pass FP8StorageLegalize();
+
/*!
 * \brief Rewrite the pointer content type of arguments,
 *        as well as Alloc internal to the function to use

python/gen_requirements.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@
         "attrs",
         "cloudpickle",
         "decorator",
+        "ml_dtypes",
         "numpy",
         "psutil",
         "scipy",

python/tvm/_ffi/runtime_ctypes.py

Lines changed: 22 additions & 0 deletions
@@ -19,6 +19,11 @@
 import ctypes
 import json
 import numpy as np
+
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
 from .base import _LIB, check_call
 
 tvm_shape_index_t = ctypes.c_int64
@@ -59,6 +64,8 @@ class DataTypeCode(object):
     FLOAT = 2
     HANDLE = 3
     BFLOAT = 4
+    E4M3Float = 6
+    E5M2Float = 7
 
 
 class DataType(ctypes.Structure):
@@ -71,6 +78,8 @@ class DataType(ctypes.Structure):
         DataTypeCode.FLOAT: "float",
         DataTypeCode.HANDLE: "handle",
         DataTypeCode.BFLOAT: "bfloat",
+        DataTypeCode.E4M3Float: "e4m3_float",
+        DataTypeCode.E5M2Float: "e5m2_float",
     }
     NUMPY2STR = {
         np.dtype(np.bool_): "bool",
@@ -97,6 +106,8 @@ class DataType(ctypes.Structure):
         "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
         "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
         "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
+        "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
+        "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
         "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
         "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
         "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -141,6 +152,12 @@ def __init__(self, type_str):
         elif head.startswith("bfloat"):
             self.type_code = DataTypeCode.BFLOAT
             head = head[6:]
+        elif head.startswith("e4m3_float"):
+            self.type_code = DataTypeCode.E4M3Float
+            head = head[10:]
+        elif head.startswith("e5m2_float"):
+            self.type_code = DataTypeCode.E5M2Float
+            head = head[10:]
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
@@ -182,6 +199,11 @@ def __ne__(self, other):
         return not self.__eq__(other)
 
 
+if ml_dtypes is not None:
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
+
 RPC_SESS_MASK = 128
 
 
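A short sketch (not from this commit) of the NumPy mapping registered above: when ml_dtypes is installed, arrays using its float8 dtypes resolve to the new TVM dtype strings:

import numpy as np
import ml_dtypes  # optional dependency added by this PR
from tvm._ffi.runtime_ctypes import DataType

x = np.zeros(4, dtype=ml_dtypes.float8_e4m3fn)
assert DataType.NUMPY2STR[x.dtype] == "e4m3_float8"
assert DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] == "e5m2_float8"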

python/tvm/contrib/nvcc.py

Lines changed: 17 additions & 0 deletions
@@ -404,3 +404,20 @@ def have_bf16(compute_version):
         return True
 
     return False
+
+
+def have_fp8(compute_version):
+    """Whether fp8 support is provided in the specified compute capability or not
+
+    Parameters
+    ----------
+    compute_version : str
+        GPU capability
+    """
+    major, minor = parse_compute_version(compute_version)
+    # fp8 is suppored in Ada Lovelace (8.9) or later architectures.
+    if major == 8 and minor == 9:
+        return True
+    if major >= 9:
+        return True
+    return False
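A usage sketch (not from this commit): gate fp8 codegen or tests on the device's compute capability, mirroring the existing have_fp16 / have_bf16 helpers:

import tvm
from tvm.contrib import nvcc

dev = tvm.cuda(0)
if dev.exist and nvcc.have_fp8(dev.compute_version):
    print("native fp8 available (sm_89 or newer)")
else:
    print("no native fp8; the FP8 legalization passes handle these dtypes instead")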

python/tvm/runtime/ndarray.py

Lines changed: 19 additions & 0 deletions
@@ -19,6 +19,11 @@
 import ctypes
 import warnings
 import numpy as np
+
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
 import tvm._ffi
 
 from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
@@ -217,6 +222,20 @@ def numpy(self):
             dtype = "int8"
         if dtype == "bfloat16":
             dtype = "uint16"
+        if dtype == "e4m3_float8":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float8_e4m3fn
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
+                )
+        if dtype == "e5m2_float8":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float8_e5m2
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
+                )
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
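A sketch (not from this commit's tests): converting an fp8 NDArray back to NumPy now goes through ml_dtypes, and raises a RuntimeError when that package is missing:

import ml_dtypes
import tvm

nd = tvm.nd.empty((8,), "e4m3_float8")   # uninitialized fp8 buffer on CPU
arr = nd.numpy()                          # ml_dtypes-backed NumPy array
assert arr.dtype == ml_dtypes.float8_e4m3fn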

python/tvm/tir/transform/transform.py

Lines changed: 28 additions & 0 deletions
@@ -40,6 +40,7 @@ def Apply(ftransform):
     fpass : tvm.transform.Pass
         The result pass
     """
+
     # pylint: disable=unused-argument
     def _transform(func, mod, ctx):
         return ftransform(func)
@@ -297,6 +298,22 @@ def BF16ComputeLegalize():
     return _ffi_api.BF16ComputeLegalize()  # type: ignore
 
 
+def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
+    """Legalize fp8 compute Ops.
+
+    Parameters
+    ----------
+    promote_dtype : str
+        The data type we promote fp8 to, options: float16/float32.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.FP8ComputeLegalize(promote_dtype_str)  # type: ignore
+
+
 def BF16StorageLegalize():
     """Legalize bf16 storage types to u16.
 
@@ -308,6 +325,17 @@ def BF16StorageLegalize():
     return _ffi_api.BF16StorageLegalize()  # type: ignore
 
 
+def FP8StorageLegalize():
+    """Legalize fp8 storage types to u8.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.FP8StorageLegalize()  # type: ignore
+
+
 def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
     """Replace redundant computations by new variables.
 
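A sketch (not from this commit) of applying the new passes explicitly, for example in a custom lowering pipeline; the default pipeline in driver_api.cc below wires them in automatically:

import tvm
from tvm import tir

seq = tvm.transform.Sequential(
    [
        tir.transform.FP8ComputeLegalize("float16"),  # promote fp8 compute, cast back to fp8
        tir.transform.FP8StorageLegalize(),           # store fp8 values as uint8
    ]
)
# legalized_mod = seq(mod)   # `mod` is an IRModule containing fp8 PrimFuncs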

src/driver/driver_api.cc

Lines changed: 2 additions & 0 deletions
@@ -210,6 +210,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::InjectSoftwarePipeline());
   pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
+  pass_list.push_back(tir::transform::FP8ComputeLegalize());
   pass_list.push_back(tir::transform::BF16ComputeLegalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
@@ -586,6 +587,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
   } else {
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
+  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
 
   mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
