Skip to content

Commit c19e5f4

Browse files
authored
[CUDA] FP4 cast and reinterpret support (#17708)
* [CUDA] FP4 cast and reinterpret support Following up on a previous PR, this PR introduces the cast and reinterpret support between `__nv_fp4_e2m1` and other dtypes. This PR also makes sure that the cast and reinterpret support vectorize. * change to float4_e2m1fn
1 parent e35a424 commit c19e5f4

File tree

13 files changed

+326
-67
lines changed

13 files changed

+326
-67
lines changed

include/tvm/runtime/data_type.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class DataType {
5858
kBFloat = kDLBfloat,
5959
kE4M3Float = 6U,
6060
kE5M2Float = 7U,
61-
kE2M1Float = 8U,
61+
kFloat4E2M1Fn = 8U,
6262
kCustomBegin = 129
6363
};
6464
/*! \brief default constructor */
@@ -88,7 +88,7 @@ class DataType {
8888
if (code == kE4M3Float || code == kE5M2Float) {
8989
ICHECK_EQ(bits, 8);
9090
}
91-
if (code == kE2M1Float) {
91+
if (code == kFloat4E2M1Fn) {
9292
ICHECK_EQ(bits, 4);
9393
}
9494
}
@@ -131,12 +131,10 @@ class DataType {
131131
bits() == 8;
132132
}
133133
/*! \return whether type is a float4 type. */
134-
bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; }
134+
bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() == 4; }
135135
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
136-
137136
bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
138-
139-
bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); }
137+
bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && bits() == 4); }
140138
/*! \return whether type is a float16 type. */
141139
bool is_float16() const { return is_float() && bits() == 16; }
142140
/*! \return whether type is a bfloat16 type. */
@@ -262,11 +260,11 @@ class DataType {
262260
*/
263261
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
264262
/*!
265-
* \brief Construct NV float4 e2m1 datatype.
263+
* \brief Construct NV float4_e2m1fn datatype.
266264
* \param lanes The number of lanes
267265
* \return The constructed data type.
268266
*/
269-
static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); }
267+
static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4E2M1Fn, 4, lanes); }
270268
/*!
271269
* \brief Construct a bool type.
272270
* \param lanes The number of lanes.
@@ -313,7 +311,7 @@ inline int GetVectorBytes(DataType dtype) {
313311
int data_bits = dtype.bits() * dtype.lanes();
314312
// allow bool to exist
315313
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
316-
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
314+
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) {
317315
return 1;
318316
}
319317
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
@@ -399,8 +397,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
399397
return "e4m3_float";
400398
case DataType::kE5M2Float:
401399
return "e5m2_float";
402-
case DataType::kE2M1Float:
403-
return "e2m1_float";
400+
case DataType::kFloat4E2M1Fn:
401+
return "float4_e2m1fn";
404402
default:
405403
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
406404
}
@@ -458,6 +456,18 @@ inline DLDataType String2DLDataType(std::string s) {
458456
} else if (s.substr(0, 4) == "uint") {
459457
t.code = kDLUInt;
460458
scan = s.c_str() + 4;
459+
} else if (s.substr(0, 13) == "float4_e2m1fn") {
460+
// Avoid being treated as "float"
461+
t.code = DataType::kFloat4E2M1Fn;
462+
t.bits = 4;
463+
scan = s.c_str() + 13;
464+
char* endpt = nullptr;
465+
if (*scan == 'x') {
466+
t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
467+
scan = endpt;
468+
}
469+
ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
470+
return t;
461471
} else if (s.substr(0, 5) == "float") {
462472
t.code = kDLFloat;
463473
scan = s.c_str() + 5;
@@ -482,10 +492,6 @@ inline DLDataType String2DLDataType(std::string s) {
482492
t.code = DataType::kE5M2Float;
483493
t.bits = 8;
484494
scan = s.c_str() + 10;
485-
} else if (s.substr(0, 10) == "e2m1_float") {
486-
t.code = DataType::kE2M1Float;
487-
t.bits = 4;
488-
scan = s.c_str() + 10;
489495
} else if (s.substr(0, 6) == "custom") {
490496
t.code = ParseCustomDatatype(s, &scan);
491497
} else {

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
505505
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
506506
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
507507

508-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1);
508+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, DataType::NVFloat4E2M1FN);
509509

510510
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
511511
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

python/tvm/_ffi/runtime_ctypes.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class DataTypeCode(object):
6868
BFLOAT = 4
6969
E4M3Float = 6
7070
E5M2Float = 7
71-
E2M1Float = 8
71+
FLOAT4E2M1FN = 8
7272

7373

7474
class DataType(ctypes.Structure):
@@ -83,7 +83,7 @@ class DataType(ctypes.Structure):
8383
DataTypeCode.BFLOAT: "bfloat",
8484
DataTypeCode.E4M3Float: "e4m3_float",
8585
DataTypeCode.E5M2Float: "e5m2_float",
86-
DataTypeCode.E2M1Float: "e2m1_float",
86+
DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
8787
}
8888
NUMPY2STR = {
8989
np.dtype(np.bool_): "bool",
@@ -114,7 +114,7 @@ class DataType(ctypes.Structure):
114114
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
115115
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
116116
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
117-
"e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1},
117+
"float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, "lanes": 1},
118118
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
119119
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
120120
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -155,6 +155,11 @@ def __init__(self, type_str):
155155
elif head.startswith("uint"):
156156
self.type_code = DataTypeCode.UINT
157157
head = head[4:]
158+
elif head.startswith("float4_e2m1fn"):
159+
# Avoid being treated as "float"
160+
self.type_code = DataTypeCode.FLOAT4E2M1FN
161+
bits = 4
162+
head = ""
158163
elif head.startswith("float"):
159164
self.type_code = DataTypeCode.FLOAT
160165
head = head[5:]
@@ -171,9 +176,6 @@ def __init__(self, type_str):
171176
elif head.startswith("e5m2_float"):
172177
self.type_code = DataTypeCode.E5M2Float
173178
head = head[10:]
174-
elif head.startswith("e2m1_float"):
175-
self.type_code = DataTypeCode.E2M1Float
176-
head = head[10:]
177179
elif head.startswith("custom"):
178180
# pylint: disable=import-outside-toplevel
179181
import tvm.runtime._ffi_api
@@ -201,7 +203,12 @@ def __repr__(self):
201203
import tvm.runtime._ffi_api
202204

203205
type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
204-
x = "%s%d" % (type_name, self.bits)
206+
if self.type_code in [
207+
DataTypeCode.FLOAT4E2M1FN,
208+
]:
209+
x = type_name
210+
else:
211+
x = "%s%d" % (type_name, self.bits)
205212
lanes_as_int = ctypes.c_int16(self.lanes).value
206213
if lanes_as_int > 1:
207214
x += "x%d" % self.lanes
@@ -238,7 +245,7 @@ def itemsize(self):
238245
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
239246
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
240247
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
241-
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
248+
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
242249

243250
RPC_SESS_MASK = 128
244251

python/tvm/runtime/ndarray.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ def copyfrom(self, source_array):
197197
source_array = np.ascontiguousarray(
198198
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
199199
)
200-
if dtype.startswith("e2m1_float4"):
200+
if self.dtype.startswith("float4_e2m1fn") and self.dtype != "float4_e2m1fn":
201+
# float4_e2m1fn in numpy is not packed.
202+
# So we need to pack the input data when converting to vectorized float4_e2m1fn type.
201203
data_bits = source_array.view(dtype="uint8")
202204
if data_bits.size % 2:
203205
data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
@@ -261,22 +263,24 @@ def numpy(self):
261263
raise RuntimeError(
262264
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
263265
)
264-
if dtype == "e2m1_float4":
266+
if dtype == "float4_e2m1fn":
265267
if ml_dtypes is not None:
266268
dtype = ml_dtypes.float4_e2m1fn
267269
else:
268270
raise RuntimeError(
269-
"ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy."
271+
"ml_dtypes is not installed, cannot convert float4_e2m1fn array to numpy."
270272
)
271273
np_arr = np.empty(shape, dtype=dtype)
272274
assert np_arr.flags["C_CONTIGUOUS"]
273275
data = np_arr.ctypes.data_as(ctypes.c_void_p)
274-
if old_dtype.startswith("e2m1_float4"):
276+
if old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn":
275277
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
276278
else:
277279
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
278280
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
279-
if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
281+
if old_dtype == "int4" or (
282+
old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn"
283+
):
280284
length = np_arr.size
281285
np_arr = np_arr.view("int8")
282286
np_arr_ret = np.empty((length,), dtype="int8")

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import functools
2020
import inspect
21-
from numbers import Integral
2221
import sys
22+
from numbers import Integral
2323
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2424

2525
# isort: off
@@ -29,8 +29,7 @@
2929

3030
import numpy as np # type: ignore
3131

32-
from tvm import tir
33-
from tvm import ir
32+
from tvm import ir, tir
3433
from tvm.ir import Type
3534
from tvm.ir.base import deprecated
3635
from tvm.runtime import String, convert, ndarray
@@ -1457,12 +1456,13 @@ def func(
14571456
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
14581457
e5m2_float8x64 = func_gen(("E5M2Float8x64"))
14591458

1460-
e2m1_float4 = func_gen(("E2M1Float4"))
1461-
e2m1_float4x4 = func_gen(("E2M1Float4x4"))
1462-
e2m1_float4x8 = func_gen(("E2M1Float4x8"))
1463-
e2m1_float4x16 = func_gen(("E2M1Float4x16"))
1464-
e2m1_float4x32 = func_gen(("E2M1Float4x32"))
1465-
e2m1_float4x64 = func_gen(("E2M1Float4x64"))
1459+
float4_e2m1fn = func_gen(("Float4E2M1fn"))
1460+
float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2"))
1461+
float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4"))
1462+
float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8"))
1463+
float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16"))
1464+
float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32"))
1465+
float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64"))
14661466

14671467

14681468
# pylint: enable=invalid-name
@@ -2013,37 +2013,38 @@ def wrapped(*args, **kwargs):
20132013
"uint64x64",
20142014
"e4m3_float8",
20152015
"e5m2_float8",
2016-
"e2m1_float4",
2016+
"float4_e2m1fn",
20172017
"float16",
20182018
"float32",
20192019
"float64",
2020+
"float4_e2m1fnx2",
20202021
"e4m3_float8x4",
20212022
"e5m2_float8x4",
2022-
"e2m1_float4x4",
2023+
"float4_e2m1fnx4",
20232024
"float16x4",
20242025
"float32x4",
20252026
"float64x4",
20262027
"e4m3_float8x8",
20272028
"e5m2_float8x8",
2028-
"e2m1_float4x8",
2029+
"float4_e2m1fnx8",
20292030
"float16x8",
20302031
"float32x8",
20312032
"float64x8",
20322033
"e4m3_float8x16",
20332034
"e5m2_float8x16",
2034-
"e2m1_float4x16",
2035+
"float4_e2m1fnx16",
20352036
"float16x16",
20362037
"float32x16",
20372038
"float64x16",
20382039
"e4m3_float8x32",
20392040
"e5m2_float8x32",
2040-
"e2m1_float4x32",
2041+
"float4_e2m1fnx32",
20412042
"float16x32",
20422043
"float32x32",
20432044
"float64x32",
20442045
"e4m3_float8x64",
20452046
"e5m2_float8x64",
2046-
"e2m1_float4x64",
2047+
"float4_e2m1fnx64",
20472048
"float16x64",
20482049
"float32x64",
20492050
"float64x64",

src/runtime/ndarray.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) {
5454
return;
5555
else if (dtype.bits == 4 && dtype.code == kDLInt)
5656
return;
57-
else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
57+
else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn)
5858
return;
5959
else
6060
ICHECK_EQ(dtype.bits % 8, 0);

src/script/ir_builder/tir/ir.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
757757
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
758758
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
759759

760-
TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
761-
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);
760+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn);
761+
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn);
762762

763763
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
764764
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);

src/target/llvm/codegen_llvm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
581581
}
582582
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
583583
etype = llvm::Type::getInt8Ty(*ctx);
584-
} else if (dtype.code() == DataType::kE2M1Float) {
584+
} else if (dtype.code() == DataType::kFloat4E2M1Fn) {
585585
etype = llvm::Type::getIntNTy(*ctx, 4);
586586
}
587587
if (!dtype.is_scalar()) {

src/target/source/codegen_c.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,11 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
789789
}
790790
}
791791

792+
if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
793+
// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
794+
// So we cannot vector load it.
795+
can_vector_load = false;
796+
}
792797
if (can_vector_load) {
793798
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
794799
HandleVolatileLoads(ref, op, os);
@@ -839,7 +844,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
839844
} else {
840845
arith::PVar<PrimExpr> base;
841846

842-
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
847+
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr) &&
848+
!value_dtype.is_float4_e2m1fn()) {
843849
std::string value = this->PrintExpr(op->value);
844850
this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
845851
} else {

0 commit comments

Comments
 (0)