Skip to content

Commit 7b88a74

Browse files
committed
[CUDA] FP4 cast and reinterpret support
Following up on a previous PR, this PR introduces cast and reinterpret support between `__nv_fp4_e2m1` and other dtypes. This PR also makes sure that the cast and reinterpret implementations support vectorized types.
1 parent e35a424 commit 7b88a74

File tree

7 files changed

+264
-17
lines changed

7 files changed

+264
-17
lines changed

python/tvm/runtime/ndarray.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ def copyfrom(self, source_array):
197197
source_array = np.ascontiguousarray(
198198
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
199199
)
200-
if dtype.startswith("e2m1_float4"):
200+
if self.dtype.startswith("e2m1_float4") and self.dtype != "e2m1_float4":
201+
# e2m1_float4 in numpy is not packed.
202+
# So we need to pack the input data when converting to vectorized e2m1_float4 type.
201203
data_bits = source_array.view(dtype="uint8")
202204
if data_bits.size % 2:
203205
data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
@@ -271,12 +273,14 @@ def numpy(self):
271273
np_arr = np.empty(shape, dtype=dtype)
272274
assert np_arr.flags["C_CONTIGUOUS"]
273275
data = np_arr.ctypes.data_as(ctypes.c_void_p)
274-
if old_dtype.startswith("e2m1_float4"):
276+
if old_dtype.startswith("e2m1_float4") and old_dtype != "e2m1_float4":
275277
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
276278
else:
277279
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
278280
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
279-
if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
281+
if old_dtype == "int4" or (
282+
old_dtype.startswith("e2m1_float4") and old_dtype != "e2m1_float4"
283+
):
280284
length = np_arr.size
281285
np_arr = np_arr.view("int8")
282286
np_arr_ret = np.empty((length,), dtype="int8")

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,6 +1458,7 @@ def func(
14581458
e5m2_float8x64 = func_gen(("E5M2Float8x64"))
14591459

14601460
e2m1_float4 = func_gen(("E2M1Float4"))
1461+
e2m1_float4x2 = func_gen(("E2M1Float4x2"))
14611462
e2m1_float4x4 = func_gen(("E2M1Float4x4"))
14621463
e2m1_float4x8 = func_gen(("E2M1Float4x8"))
14631464
e2m1_float4x16 = func_gen(("E2M1Float4x16"))
@@ -2017,6 +2018,7 @@ def wrapped(*args, **kwargs):
20172018
"float16",
20182019
"float32",
20192020
"float64",
2021+
"e2m1_float4x2",
20202022
"e4m3_float8x4",
20212023
"e5m2_float8x4",
20222024
"e2m1_float4x4",

src/target/source/codegen_c.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,11 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
789789
}
790790
}
791791

792+
if (value_dtype.is_e2m1_float4() && lanes != 1) {
793+
// An e2m1_float4 element has 4 bits, which is an incomplete byte.
794+
// So we cannot vector load it.
795+
can_vector_load = false;
796+
}
792797
if (can_vector_load) {
793798
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
794799
HandleVolatileLoads(ref, op, os);
@@ -839,7 +844,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
839844
} else {
840845
arith::PVar<PrimExpr> base;
841846

842-
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
847+
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr) &&
848+
!value_dtype.is_e2m1_float4()) {
843849
std::string value = this->PrintExpr(op->value);
844850
this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
845851
} else {

src/target/source/codegen_cuda.cc

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ std::string GetFP4Type(DataType type) {
8282
} else if (lanes == 4) {
8383
vec = "x4";
8484
} else {
85-
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
85+
LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP4";
8686
}
8787
stream << "__nv_fp4";
8888
std::string suffix;
@@ -196,7 +196,7 @@ std::string CodeGenCUDA::Finish() {
196196
decl_stream << "#include <cuda_fp4.h>\n";
197197
decl_stream << "#endif\n\n";
198198
}
199-
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
199+
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, enable_fp4_);
200200

201201
if (enable_warp_shuffle_) {
202202
decl_stream << _cuda_warp_intrinsic_util;
@@ -597,6 +597,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
597597
}
598598
ICHECK(!type_name.empty());
599599
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
600+
} else if (t.is_e2m1_float4()) {
601+
os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; })((" << vec
602+
<< ".__x >> " << i * 4 << ") & 0xF)";
600603
} else {
601604
os << vec << "." << access[i];
602605
}
@@ -1036,8 +1039,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
10361039
var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
10371040

10381041
os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
1039-
os << dst << "[" + this->PrintExpr(dst_ind) + "]"
1040-
<< " = " << src << "[" << src_offset << " + local_id];\n";
1042+
os << dst << "[" + this->PrintExpr(dst_ind) + "] = " << src << "[" << src_offset
1043+
<< " + local_id];\n";
10411044
os << "}\n";
10421045

10431046
} else if (op->op.same_as(builtin::mma_fill())) {
@@ -1155,6 +1158,82 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
11551158
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
11561159
<< guard << ")\n";
11571160
stream << ");\n";
1161+
} else if (op->op.same_as(builtin::reinterpret())) {
1162+
DataType tgt_dtype = op->dtype;
1163+
DataType src_dtype = op->args[0]->dtype;
1164+
PrimExpr value = op->args[0];
1165+
1166+
// Handle e2m1_float4 reinterpret
1167+
if (!src_dtype.is_e2m1_float4() && !tgt_dtype.is_e2m1_float4()) {
1168+
return CodeGenC::VisitExpr_(op, os);
1169+
}
1170+
if (src_dtype == tgt_dtype ||
1171+
tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * src_dtype.bits()) {
1172+
return CodeGenC::VisitExpr_(op, os);
1173+
}
1174+
CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
1175+
<< "E2M1 float4 reinterpret expects source and target to have the same number of lanes. "
1176+
<< "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
1177+
CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
1178+
<< "E2M1 float4 reinterpret expects source and target to have the same number of bytes. "
1179+
<< "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
1180+
1181+
int lanes = tgt_dtype.lanes();
1182+
1183+
int ssa_scope = BeginScope();
1184+
if (lanes == 1) {
1185+
// The case of lanes == 1 is the same as the normal reinterpret,
1186+
// except that we allow the src and dst dtype to have different number of bits.
1187+
std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
1188+
os << "(*(";
1189+
this->PrintType(tgt_dtype, os);
1190+
os << " *)(&(" << rhs << ")))";
1191+
} else if (lanes == 2) {
1192+
if (tgt_dtype.is_e2m1_float4()) {
1193+
// We view the source as an uint16, and then extract bits of two fp4 numbers,
1194+
// and finally reinterpret the result as fp4x2.
1195+
value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value});
1196+
tir::Var temp_var("temp_var", DataType::UInt(16));
1197+
value = tir::Let(
1198+
temp_var, value,
1199+
tir::Cast(DataType::UInt(8), (temp_var & IntImm(DataType::UInt(16), 0xF)) |
1200+
((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0))));
1201+
} else {
1202+
value = tir::Cast(DataType::UInt(16),
1203+
tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}));
1204+
tir::Var temp_var("temp_var", DataType::UInt(16));
1205+
value = tir::Let(temp_var, value,
1206+
(temp_var & IntImm(DataType::UInt(16), 0xF)) |
1207+
((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4));
1208+
}
1209+
os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
1210+
} else if (lanes == 4) {
1211+
if (tgt_dtype.is_e2m1_float4()) {
1212+
// We view the source as an uint32, and then extract bits of four fp4 numbers,
1213+
// and finally reinterpret the result as fp4x4.
1214+
value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value});
1215+
tir::Var temp_var("temp_var", DataType::UInt(32));
1216+
value = tir::Let(temp_var, value,
1217+
tir::Cast(DataType::UInt(16),
1218+
(temp_var & IntImm(DataType::UInt(32), 0xF)) |
1219+
((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) |
1220+
((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) |
1221+
((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000))));
1222+
} else {
1223+
value = tir::Cast(DataType::UInt(32),
1224+
tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}));
1225+
tir::Var temp_var("temp_var", DataType::UInt(32));
1226+
value = tir::Let(temp_var, value,
1227+
(temp_var & IntImm(DataType::UInt(32), 0xF)) |
1228+
((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) |
1229+
((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) |
1230+
((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12));
1231+
}
1232+
os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
1233+
} else {
1234+
LOG(FATAL) << "Invalid number of lanes for e2m1_float4 reinterpret: " << lanes;
1235+
}
1236+
EndScope(ssa_scope);
11581237
} else {
11591238
CodeGenC::VisitExpr_(op, os);
11601239
}

src/target/source/literal/cuda_half_t.h

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,9 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
385385
386386
)";
387387

388-
void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8) {
389-
if (enable_fp16 || enable_fp8) {
388+
void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8,
389+
bool enable_fp4) {
390+
if (enable_fp16 || enable_fp8 || enable_fp4) {
390391
stream << R"(
391392
struct __align__(8) half4 {
392393
__half x, y, z, w;
@@ -455,13 +456,47 @@ struct __align__(8) half4 {
455456
result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
456457
return result;
457458
}
459+
)";
460+
}
461+
if (enable_fp4) {
462+
stream << R"(
463+
__host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
464+
__nv_fp4x2_storage_t lo_part, hi_part;
465+
lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
466+
hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
467+
__half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
468+
__half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
469+
x = reinterpret_cast<__half*>(&lo_half2)[0];
470+
y = reinterpret_cast<__half*>(&lo_half2)[1];
471+
z = reinterpret_cast<__half*>(&hi_half2)[0];
472+
w = reinterpret_cast<__half*>(&hi_half2)[1];
473+
}
474+
__host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
475+
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
476+
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
477+
return __nv_fp4x4_e2m1(lo_half2, hi_half2);
478+
}
458479
)";
459480
}
460481
stream << R"(
461482
};
462483
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
463484
return half4(x, y, z, w);
464485
}
486+
)";
487+
}
488+
if (enable_fp4) {
489+
stream << R"(
490+
__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
491+
__nv_fp4x2_e2m1 result;
492+
result.__x = (x.__x) | (y.__x << 4);
493+
return result;
494+
}
495+
__device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b, __nv_fp4_e2m1 c, __nv_fp4_e2m1 d) {
496+
__nv_fp4x4_e2m1 result;
497+
result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | (static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) | (static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | (static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
498+
return result;
499+
}
465500
)";
466501
}
467502
}

src/tir/op/op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
425425
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
426426
if (value.dtype() == t) return value;
427427
if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) {
428-
ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
429-
<< "Bitcast requires size match " << t << " vs " << value.dtype();
428+
ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes() ||
429+
((value.dtype().is_e2m1_float4() || t.is_e2m1_float4()) &&
430+
value.dtype().bytes() * value.dtype().lanes() == 1 && t.bytes() * t.lanes() == 1))
431+
<< "Reinterpret requires size match " << t << " vs " << value.dtype();
430432
}
431433
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
432434
}

0 commit comments

Comments
 (0)