@@ -82,7 +82,7 @@ std::string GetFP4Type(DataType type) {
8282 } else if (lanes == 4 ) {
8383 vec = " x4" ;
8484 } else {
85- LOG (FATAL) << " Only support scalar and vector types of width (2, 4, 8 ) for FP8" ;
85+ LOG (FATAL) << " Only support scalar and vector types of width (2, 4) for FP8" ;
8686 }
8787 stream << " __nv_fp4" ;
8888 std::string suffix;
@@ -196,7 +196,7 @@ std::string CodeGenCUDA::Finish() {
196196 decl_stream << " #include <cuda_fp4.h>\n " ;
197197 decl_stream << " #endif\n\n " ;
198198 }
199- declare_vector_type_extensions (decl_stream, enable_fp16_, enable_fp8_);
199+ declare_vector_type_extensions (decl_stream, enable_fp16_, enable_fp8_, enable_fp4_ );
200200
201201 if (enable_warp_shuffle_) {
202202 decl_stream << _cuda_warp_intrinsic_util;
@@ -597,6 +597,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
597597 }
598598 ICHECK (!type_name.empty ());
599599 os << " ((" << type_name << " 2*)(&(" << vec << " ." << access[i / 2 ] << " )))->" << access[i % 2 ];
600+ } else if (t.is_e2m1_float4 ()) {
601+ os << " ([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; })((" << vec
602+ << " .__x >> " << i * 4 << " ) & 0xF)" ;
600603 } else {
601604 os << vec << " ." << access[i];
602605 }
@@ -1036,8 +1039,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
10361039 var_idmap_[inverse_index_map->initial_indices [1 ].get ()] = " local_id" ;
10371040
10381041 os << " for (int local_id = 0; local_id < 8; ++local_id) {\n " ;
1039- os << dst << " [" + this ->PrintExpr (dst_ind) + " ]"
1040- << " = " << src << " [ " << src_offset << " + local_id];\n " ;
1042+ os << dst << " [" + this ->PrintExpr (dst_ind) + " ] = " << src << " [ " << src_offset
1043+ << " + local_id];\n " ;
10411044 os << " }\n " ;
10421045
10431046 } else if (op->op .same_as (builtin::mma_fill ())) {
@@ -1155,6 +1158,82 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
11551158 stream << " : \" l\" ((void*)(" << global_buffer << " +" << global_addr << " )), \" r\" ((int)"
11561159 << guard << " )\n " ;
11571160 stream << " );\n " ;
1161+ } else if (op->op .same_as (builtin::reinterpret ())) {
1162+ DataType tgt_dtype = op->dtype ;
1163+ DataType src_dtype = op->args [0 ]->dtype ;
1164+ PrimExpr value = op->args [0 ];
1165+
1166+ // Handle e2m1_float4 reinterpret
1167+ if (!src_dtype.is_e2m1_float4 () && !tgt_dtype.is_e2m1_float4 ()) {
1168+ return CodeGenC::VisitExpr_ (op, os);
1169+ }
1170+ if (src_dtype == tgt_dtype ||
1171+ tgt_dtype.lanes () * tgt_dtype.bits () == src_dtype.lanes () * src_dtype.bits ()) {
1172+ return CodeGenC::VisitExpr_ (op, os);
1173+ }
1174+ CHECK_EQ (tgt_dtype.lanes (), src_dtype.lanes ())
1175+ << " E2M1 float4 reinterpret expects source and target to have the same number of lanes. "
1176+ << " Source dtype: " << src_dtype << " , Target dtype: " << tgt_dtype;
1177+ CHECK_EQ (tgt_dtype.bytes (), src_dtype.bytes ())
1178+ << " E2M1 float4 reinterpret expects source and target to have the same number of bytes. "
1179+ << " Source dtype: " << src_dtype << " , Target dtype: " << tgt_dtype;
1180+
1181+ int lanes = tgt_dtype.lanes ();
1182+
1183+ int ssa_scope = BeginScope ();
1184+ if (lanes == 1 ) {
1185+ // The case of lane=1 is same as the normal reinterpret,
1186+ // except that we allow the src and dst dtype to have different number of bits.
1187+ std::string rhs = SSAGetID (PrintExpr (value), src_dtype);
1188+ os << " (*(" ;
1189+ this ->PrintType (tgt_dtype, os);
1190+ os << " *)(&(" << rhs << " )))" ;
1191+ } else if (lanes == 2 ) {
1192+ if (tgt_dtype.is_e2m1_float4 ()) {
1193+ // We view the source as an uint16, and then extract bits of two fp4 numbers,
1194+ // and finally reinterpret the result as fp4x2.
1195+ value = tir::Call (DataType::UInt (16 ), tir::builtin::reinterpret (), {value});
1196+ tir::Var temp_var (" temp_var" , DataType::UInt (16 ));
1197+ value = tir::Let (
1198+ temp_var, value,
1199+ tir::Cast (DataType::UInt (8 ), (temp_var & IntImm (DataType::UInt (16 ), 0xF )) |
1200+ ((temp_var >> 4 ) & IntImm (DataType::UInt (16 ), 0xF0 ))));
1201+ } else {
1202+ value = tir::Cast (DataType::UInt (16 ),
1203+ tir::Call (DataType::UInt (8 ), tir::builtin::reinterpret (), {value}));
1204+ tir::Var temp_var (" temp_var" , DataType::UInt (16 ));
1205+ value = tir::Let (temp_var, value,
1206+ (temp_var & IntImm (DataType::UInt (16 ), 0xF )) |
1207+ ((temp_var & IntImm (DataType::UInt (16 ), 0xF0 )) << 4 ));
1208+ }
1209+ os << PrintExpr (tir::Call (tgt_dtype, tir::builtin::reinterpret (), {value}));
1210+ } else if (lanes == 4 ) {
1211+ if (tgt_dtype.is_e2m1_float4 ()) {
1212+ // We view the source as an uint32, and then extract bits of four fp4 numbers,
1213+ // and finally reinterpret the result as fp4x4.
1214+ value = tir::Call (DataType::UInt (32 ), tir::builtin::reinterpret (), {value});
1215+ tir::Var temp_var (" temp_var" , DataType::UInt (32 ));
1216+ value = tir::Let (temp_var, value,
1217+ tir::Cast (DataType::UInt (16 ),
1218+ (temp_var & IntImm (DataType::UInt (32 ), 0xF )) |
1219+ ((temp_var >> 4 ) & IntImm (DataType::UInt (32 ), 0xF0 )) |
1220+ ((temp_var >> 8 ) & IntImm (DataType::UInt (32 ), 0xF00 )) |
1221+ ((temp_var >> 12 ) & IntImm (DataType::UInt (32 ), 0xF000 ))));
1222+ } else {
1223+ value = tir::Cast (DataType::UInt (32 ),
1224+ tir::Call (DataType::UInt (16 ), tir::builtin::reinterpret (), {value}));
1225+ tir::Var temp_var (" temp_var" , DataType::UInt (32 ));
1226+ value = tir::Let (temp_var, value,
1227+ (temp_var & IntImm (DataType::UInt (32 ), 0xF )) |
1228+ ((temp_var & IntImm (DataType::UInt (32 ), 0xF0 )) << 4 ) |
1229+ ((temp_var & IntImm (DataType::UInt (32 ), 0xF00 )) << 8 ) |
1230+ ((temp_var & IntImm (DataType::UInt (32 ), 0xF000 )) << 12 ));
1231+ }
1232+ os << PrintExpr (tir::Call (tgt_dtype, tir::builtin::reinterpret (), {value}));
1233+ } else {
1234+ LOG (FATAL) << " Invalid number of lanes for e2m1_float4 reinterpret: " << lanes;
1235+ }
1236+ EndScope (ssa_scope);
11581237 } else {
11591238 CodeGenC::VisitExpr_ (op, os);
11601239 }
0 commit comments