@@ -58,6 +58,7 @@ class DataType {
5858 kBFloat = kDLBfloat ,
5959 kE4M3Float = 6U ,
6060 kE5M2Float = 7U ,
61+ kE2M1Float = 8U ,
6162 kCustomBegin = 129
6263 };
6364 /* ! \brief default constructor */
@@ -87,6 +88,9 @@ class DataType {
8788 if (code == kE4M3Float || code == kE5M2Float ) {
8889 ICHECK_EQ (bits, 8 );
8990 }
91+ if (code == kE2M1Float ) {
92+ ICHECK_EQ (bits, 4 );
93+ }
9094 }
9195 /* ! \return The type code. */
9296 int code () const { return static_cast <int >(data_.code ); }
@@ -126,9 +130,13 @@ class DataType {
126130 code () == DataType::kE5M2Float ) &&
127131 bits () == 8 ;
128132 }
133+ /* ! \return whether type is a float4 type. */
134+ bool is_float4 () const { return code () == DataType::kE2M1Float && bits () == 4 ; }
129135 bool is_e4m3_float8 () const { return (code () == DataType::kE4M3Float && bits () == 8 ); }
130136
131137 bool is_e5m2_float8 () const { return (code () == DataType::kE5M2Float && bits () == 8 ); }
138+
139+ bool is_e2m1_float4 () const { return (code () == DataType::kE2M1Float && bits () == 4 ); }
132140 /* ! \return whether type is a float16 type. */
133141 bool is_float16 () const { return is_float () && bits () == 16 ; }
134142 /* ! \return whether type is a bfloat16 type. */
@@ -253,6 +261,12 @@ class DataType {
253261 * \return The constructed data type.
254262 */
255263 static DataType NVFloat8E5M2 (int lanes = 1 ) { return DataType (kE5M2Float , 8 , lanes); }
264+ /* !
265+ * \brief Construct NV float4 e2m1 datatype.
266+ * \param lanes The number of lanes
267+ * \return The constructed data type.
268+ */
269+ static DataType NVFloat4E2M1 (int lanes = 1 ) { return DataType (kE2M1Float , 4 , lanes); }
256270 /* !
257271 * \brief Construct a bool type.
258272 * \param lanes The number of lanes.
@@ -299,7 +313,7 @@ inline int GetVectorBytes(DataType dtype) {
299313 int data_bits = dtype.bits () * dtype.lanes ();
300314 // allow bool to exist
301315 if (dtype == DataType::Bool () || dtype == DataType::Int (4 ) || dtype == DataType::UInt (4 ) ||
302- dtype == DataType::Int (1 )) {
316+ dtype == DataType::Int (1 ) || dtype == DataType::NVFloat4E2M1 () ) {
303317 return 1 ;
304318 }
305319 ICHECK_EQ (data_bits % 8 , 0U ) << " Need to load/store by multiple of bytes" ;
@@ -385,6 +399,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
385399 return " e4m3_float" ;
386400 case DataType::kE5M2Float :
387401 return " e5m2_float" ;
402+ case DataType::kE2M1Float :
403+ return " e2m1_float" ;
388404 default :
389405 LOG (FATAL) << " unknown type_code=" << static_cast <int >(type_code);
390406 }
@@ -466,6 +482,10 @@ inline DLDataType String2DLDataType(std::string s) {
466482 t.code = DataType::kE5M2Float ;
467483 t.bits = 8 ;
468484 scan = s.c_str () + 10 ;
485+ } else if (s.substr (0 , 10 ) == " e2m1_float" ) {
486+ t.code = DataType::kE2M1Float ;
487+ t.bits = 4 ;
488+ scan = s.c_str () + 10 ;
469489 } else if (s.substr (0 , 6 ) == " custom" ) {
470490 t.code = ParseCustomDatatype (s, &scan);
471491 } else {
0 commit comments