Skip to content

Commit 7bedfeb

Browse files
jinhongyiiyongwww
andauthored
[Codegen] FP4 support (#17630)
* fp4 * fix test * fix lint * fix * Test with manually built images --------- Co-authored-by: Yong Wu <[email protected]>
1 parent 61f6e7f commit 7bedfeb

File tree

21 files changed

+267
-19
lines changed

21 files changed

+267
-19
lines changed

ci/jenkins/docker-images.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
# This data file is read during when Jenkins runs job to determine docker images.
1919
[jenkins]
2020
ci_arm: tlcpack/ci-arm:20250226-223225-63bc315f
21-
ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f
22-
ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f
21+
ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f_patch
22+
ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f_patch
2323
ci_hexagon: tlcpack/ci-hexagon:20250226-223225-63bc315f
2424
ci_i386: tlcpack/ci-i386:20250226-223225-63bc315f
2525
ci_lint: tlcpack/ci-lint:20250226-223225-63bc315f

ci/jenkins/unity_jenkinsfile.groovy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
3131

3232
// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
33-
ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f'
34-
ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f'
33+
ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f_patch'
34+
ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f_patch'
3535
// <--- End of regex-scanned config.
3636

3737
// Parameters to allow overriding (in Jenkins UI), the images

include/tvm/runtime/data_type.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class DataType {
5858
kBFloat = kDLBfloat,
5959
kE4M3Float = 6U,
6060
kE5M2Float = 7U,
61+
kE2M1Float = 8U,
6162
kCustomBegin = 129
6263
};
6364
/*! \brief default constructor */
@@ -87,6 +88,9 @@ class DataType {
8788
if (code == kE4M3Float || code == kE5M2Float) {
8889
ICHECK_EQ(bits, 8);
8990
}
91+
if (code == kE2M1Float) {
92+
ICHECK_EQ(bits, 4);
93+
}
9094
}
9195
/*! \return The type code. */
9296
int code() const { return static_cast<int>(data_.code); }
@@ -126,9 +130,13 @@ class DataType {
126130
code() == DataType::kE5M2Float) &&
127131
bits() == 8;
128132
}
133+
/*! \return whether type is a float4 type. */
134+
bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; }
129135
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
130136

131137
bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
138+
139+
bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); }
132140
/*! \return whether type is a float16 type. */
133141
bool is_float16() const { return is_float() && bits() == 16; }
134142
/*! \return whether type is a bfloat16 type. */
@@ -253,6 +261,12 @@ class DataType {
253261
* \return The constructed data type.
254262
*/
255263
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
264+
/*!
265+
* \brief Construct NV float4 e2m1 datatype.
266+
* \param lanes The number of lanes
267+
* \return The constructed data type.
268+
*/
269+
static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); }
256270
/*!
257271
* \brief Construct a bool type.
258272
* \param lanes The number of lanes.
@@ -299,7 +313,7 @@ inline int GetVectorBytes(DataType dtype) {
299313
int data_bits = dtype.bits() * dtype.lanes();
300314
// allow bool to exist
301315
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
302-
dtype == DataType::Int(1)) {
316+
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
303317
return 1;
304318
}
305319
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
@@ -385,6 +399,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
385399
return "e4m3_float";
386400
case DataType::kE5M2Float:
387401
return "e5m2_float";
402+
case DataType::kE2M1Float:
403+
return "e2m1_float";
388404
default:
389405
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
390406
}
@@ -466,6 +482,10 @@ inline DLDataType String2DLDataType(std::string s) {
466482
t.code = DataType::kE5M2Float;
467483
t.bits = 8;
468484
scan = s.c_str() + 10;
485+
} else if (s.substr(0, 10) == "e2m1_float") {
486+
t.code = DataType::kE2M1Float;
487+
t.bits = 4;
488+
scan = s.c_str() + 10;
469489
} else if (s.substr(0, 6) == "custom") {
470490
t.code = ParseCustomDatatype(s, &scan);
471491
} else {

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
505505
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
506506
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
507507

508+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1);
509+
508510
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
509511
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
510512

include/tvm/tir/op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
940940
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
941941
}
942942
}
943-
if (t.is_float() || t.is_bfloat16() || t.is_float8())
943+
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4())
944944
return FloatImm(t, static_cast<double>(value), span);
945945
// For now, we store const scalar values of custom datatypes within doubles; later, during the
946946
// datatypes lowering pass, we will lower the value to its true representation in the format

python/tvm/_ffi/runtime_ctypes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class DataTypeCode(object):
6868
BFLOAT = 4
6969
E4M3Float = 6
7070
E5M2Float = 7
71+
E2M1Float = 8
7172

7273

7374
class DataType(ctypes.Structure):
@@ -82,6 +83,7 @@ class DataType(ctypes.Structure):
8283
DataTypeCode.BFLOAT: "bfloat",
8384
DataTypeCode.E4M3Float: "e4m3_float",
8485
DataTypeCode.E5M2Float: "e5m2_float",
86+
DataTypeCode.E2M1Float: "e2m1_float",
8587
}
8688
NUMPY2STR = {
8789
np.dtype(np.bool_): "bool",
@@ -112,6 +114,7 @@ class DataType(ctypes.Structure):
112114
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
113115
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
114116
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
117+
"e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1},
115118
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
116119
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
117120
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -168,6 +171,9 @@ def __init__(self, type_str):
168171
elif head.startswith("e5m2_float"):
169172
self.type_code = DataTypeCode.E5M2Float
170173
head = head[10:]
174+
elif head.startswith("e2m1_float"):
175+
self.type_code = DataTypeCode.E2M1Float
176+
head = head[10:]
171177
elif head.startswith("custom"):
172178
# pylint: disable=import-outside-toplevel
173179
import tvm.runtime._ffi_api
@@ -232,6 +238,7 @@ def itemsize(self):
232238
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
233239
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
234240
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
241+
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
235242

236243
RPC_SESS_MASK = 128
237244

python/tvm/contrib/nvcc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,19 @@ def have_fp8(compute_version):
445445
if major >= 9:
446446
return True
447447
return False
448+
449+
450+
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp4")
451+
def have_fp4(compute_version):
452+
"""Whether fp4 support is provided in the specified compute capability or not
453+
454+
Parameters
455+
----------
456+
compute_version : str
457+
GPU capability
458+
"""
459+
major, minor = parse_compute_version(compute_version)
460+
# fp4 is suppored in Blackwell (10.0) or later architectures.
461+
if major == 10 and minor == 0:
462+
return True
463+
return False

python/tvm/runtime/ndarray.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,13 @@ def copyfrom(self, source_array):
197197
source_array = np.ascontiguousarray(
198198
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
199199
)
200+
if dtype.startswith("e2m1_float4"):
201+
data_bits = source_array.view(dtype="uint8")
202+
if data_bits.size % 2:
203+
data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
204+
data_bits = data_bits.reshape(-1, 2)
205+
packed = ((data_bits[:, 0] & 0x0F) << 4) | (data_bits[:, 1] & 0x0F)
206+
source_array = packed.astype(np.int8)
200207
assert source_array.flags["C_CONTIGUOUS"]
201208
data = source_array.ctypes.data_as(ctypes.c_void_p)
202209
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
@@ -254,20 +261,32 @@ def numpy(self):
254261
raise RuntimeError(
255262
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
256263
)
264+
if dtype == "e2m1_float4":
265+
if ml_dtypes is not None:
266+
dtype = ml_dtypes.float4_e2m1fn
267+
else:
268+
raise RuntimeError(
269+
"ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy."
270+
)
257271
np_arr = np.empty(shape, dtype=dtype)
258272
assert np_arr.flags["C_CONTIGUOUS"]
259273
data = np_arr.ctypes.data_as(ctypes.c_void_p)
260-
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
274+
if old_dtype.startswith("e2m1_float4"):
275+
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
276+
else:
277+
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
261278
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
262-
if old_dtype == "int4":
279+
if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
263280
length = np_arr.size
281+
np_arr = np_arr.view("int8")
264282
np_arr_ret = np.empty((length,), dtype="int8")
265283
np_arr = np_arr.reshape((length,))
266284
old_index = np.bitwise_and(np_arr, 0x0F)
267285
even_index = np.bitwise_and(np_arr >> 4, 0x0F)
268286
np_arr_ret[1::2] = old_index[0 : length // 2]
269287
np_arr_ret[0::2] = even_index[0 : length // 2]
270-
return np_arr_ret.reshape(shape)
288+
return np_arr_ret.reshape(shape).view(dtype)
289+
271290
return np_arr
272291

273292
def copyto(self, target, mem_scope=None):

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,14 @@ def func(
14571457
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
14581458
e5m2_float8x64 = func_gen(("E5M2Float8x64"))
14591459

1460+
e2m1_float4 = func_gen(("E2M1Float4"))
1461+
e2m1_float4x4 = func_gen(("E2M1Float4x4"))
1462+
e2m1_float4x8 = func_gen(("E2M1Float4x8"))
1463+
e2m1_float4x16 = func_gen(("E2M1Float4x16"))
1464+
e2m1_float4x32 = func_gen(("E2M1Float4x32"))
1465+
e2m1_float4x64 = func_gen(("E2M1Float4x64"))
1466+
1467+
14601468
# pylint: enable=invalid-name
14611469

14621470

@@ -2005,31 +2013,37 @@ def wrapped(*args, **kwargs):
20052013
"uint64x64",
20062014
"e4m3_float8",
20072015
"e5m2_float8",
2016+
"e2m1_float4",
20082017
"float16",
20092018
"float32",
20102019
"float64",
20112020
"e4m3_float8x4",
20122021
"e5m2_float8x4",
2022+
"e2m1_float4x4",
20132023
"float16x4",
20142024
"float32x4",
20152025
"float64x4",
20162026
"e4m3_float8x8",
20172027
"e5m2_float8x8",
2028+
"e2m1_float4x8",
20182029
"float16x8",
20192030
"float32x8",
20202031
"float64x8",
20212032
"e4m3_float8x16",
20222033
"e5m2_float8x16",
2034+
"e2m1_float4x16",
20232035
"float16x16",
20242036
"float32x16",
20252037
"float64x16",
20262038
"e4m3_float8x32",
20272039
"e5m2_float8x32",
2040+
"e2m1_float4x32",
20282041
"float16x32",
20292042
"float32x32",
20302043
"float64x32",
20312044
"e4m3_float8x64",
20322045
"e5m2_float8x64",
2046+
"e2m1_float4x64",
20332047
"float16x64",
20342048
"float32x64",
20352049
"float64x64",

src/ir/expr.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
110110
FloatImm::FloatImm(DataType dtype, double value, Span span) {
111111
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
112112

113-
ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() ||
113+
ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4() ||
114114
dtype.code() >= DataType::kCustomBegin)
115115
<< "ValueError: FloatImm supports only float, but " << dtype << " was supplied.";
116116

@@ -137,6 +137,11 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) {
137137
<< dtype;
138138
ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of "
139139
<< dtype;
140+
} else if (dtype.is_float4()) {
141+
ICHECK_GE(value, -support::kMaxE2M1)
142+
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
143+
ICHECK_LE(value, support::kMaxE2M1)
144+
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
140145
}
141146
}
142147
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();

0 commit comments

Comments
 (0)