@@ -58,6 +58,7 @@ class DataType {
5858 kBFloat = kDLBfloat ,
5959 kE4M3Float = 6U ,
6060 kE5M2Float = 7U ,
61+ kE2M1Float = 8U ,
6162 kCustomBegin = 129
6263 };
6364 /* ! \brief default constructor */
@@ -87,6 +88,9 @@ class DataType {
8788 if (code == kE4M3Float || code == kE5M2Float ) {
8889 ICHECK_EQ (bits, 8 );
8990 }
91+ if (code == kE2M1Float ) {
92+ ICHECK_EQ (bits, 4 );
93+ }
9094 }
9195 /* ! \return The type code. */
9296 int code () const { return static_cast <int >(data_.code ); }
@@ -126,9 +130,15 @@ class DataType {
126130 code () == DataType::kE5M2Float ) &&
127131 bits () == 8 ;
128132 }
133+ /* ! \return whether type is a float4 type. */
134+ bool is_float4 () const {
135+ return code () == DataType::kE2M1Float && bits () == 4 ;
136+ }
129137 bool is_e4m3_float8 () const { return (code () == DataType::kE4M3Float && bits () == 8 ); }
130138
131139 bool is_e5m2_float8 () const { return (code () == DataType::kE5M2Float && bits () == 8 ); }
140+
141+ bool is_e2m1_float4 () const { return (code () == DataType::kE2M1Float && bits () == 4 ); }
132142 /* ! \return whether type is a float16 type. */
133143 bool is_float16 () const { return is_float () && bits () == 16 ; }
134144 /* ! \return whether type is a bfloat16 type. */
@@ -253,6 +263,12 @@ class DataType {
253263 * \return The constructed data type.
254264 */
255265 static DataType NVFloat8E5M2 (int lanes = 1 ) { return DataType (kE5M2Float , 8 , lanes); }
266+ /* !
267+ * \brief Construct NV float4 e2m1 datatype.
268+ * \param lanes The number of lanes
269+ * \return The constructed data type.
270+ */
271+ static DataType NVFloat4E2M1 (int lanes = 1 ) { return DataType (kE2M1Float , 4 , lanes); }
256272 /* !
257273 * \brief Construct a bool type.
258274 * \param lanes The number of lanes.
@@ -299,7 +315,7 @@ inline int GetVectorBytes(DataType dtype) {
299315 int data_bits = dtype.bits () * dtype.lanes ();
300316 // allow bool to exist
301317 if (dtype == DataType::Bool () || dtype == DataType::Int (4 ) || dtype == DataType::UInt (4 ) ||
302- dtype == DataType::Int (1 )) {
318+ dtype == DataType::Int (1 ) || dtype == DataType::NVFloat4E2M1 () ) {
303319 return 1 ;
304320 }
305321 ICHECK_EQ (data_bits % 8 , 0U ) << " Need to load/store by multiple of bytes" ;
@@ -385,6 +401,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
385401 return " e4m3_float" ;
386402 case DataType::kE5M2Float :
387403 return " e5m2_float" ;
404+ case DataType::kE2M1Float :
405+ return " e2m1_float" ;
388406 default :
389407 LOG (FATAL) << " unknown type_code=" << static_cast <int >(type_code);
390408 }
@@ -466,6 +484,10 @@ inline DLDataType String2DLDataType(std::string s) {
466484 t.code = DataType::kE5M2Float ;
467485 t.bits = 8 ;
468486 scan = s.c_str () + 10 ;
487+ } else if (s.substr (0 , 10 ) == " e2m1_float" ) {
488+ t.code = DataType::kE2M1Float ;
489+ t.bits = 4 ;
490+ scan = s.c_str () + 10 ;
469491 } else if (s.substr (0 , 6 ) == " custom" ) {
470492 t.code = ParseCustomDatatype (s, &scan);
471493 } else {
0 commit comments