@@ -68,7 +68,7 @@ class DataTypeCode(object):
6868 BFLOAT = 4
6969 E4M3Float = 6
7070 E5M2Float = 7
71- E2M1Float = 8
71+ FLOAT4E2M1FN = 8
7272
7373
7474class DataType (ctypes .Structure ):
@@ -83,7 +83,7 @@ class DataType(ctypes.Structure):
8383 DataTypeCode .BFLOAT : "bfloat" ,
8484 DataTypeCode .E4M3Float : "e4m3_float" ,
8585 DataTypeCode .E5M2Float : "e5m2_float" ,
86- DataTypeCode .E2M1Float : "e2m1_float " ,
86+ DataTypeCode .FLOAT4E2M1FN : "float4_e2m1fn " ,
8787 }
8888 NUMPY2STR = {
8989 np .dtype (np .bool_ ): "bool" ,
@@ -114,7 +114,7 @@ class DataType(ctypes.Structure):
114114 "uint64" : {"type_code" : DataTypeCode .UINT , "bits" : 64 , "lanes" : 1 },
115115 "e4m3_float8" : {"type_code" : DataTypeCode .E4M3Float , "bits" : 8 , "lanes" : 1 },
116116 "e5m2_float8" : {"type_code" : DataTypeCode .E5M2Float , "bits" : 8 , "lanes" : 1 },
117- "e2m1_float4 " : {"type_code" : DataTypeCode .E2M1Float , "bits" : 4 , "lanes" : 1 },
117+ "float4_e2m1fn " : {"type_code" : DataTypeCode .FLOAT4E2M1FN , "bits" : 4 , "lanes" : 1 },
118118 "float16" : {"type_code" : DataTypeCode .FLOAT , "bits" : 16 , "lanes" : 1 },
119119 "float32" : {"type_code" : DataTypeCode .FLOAT , "bits" : 32 , "lanes" : 1 },
120120 "float64" : {"type_code" : DataTypeCode .FLOAT , "bits" : 64 , "lanes" : 1 },
@@ -155,6 +155,11 @@ def __init__(self, type_str):
155155 elif head .startswith ("uint" ):
156156 self .type_code = DataTypeCode .UINT
157157 head = head [4 :]
158+ elif head .startswith ("float4_e2m1fn" ):
159+ # Avoid being treated as "float"
160+ self .type_code = DataTypeCode .FLOAT4E2M1FN
161+ bits = 4
162+ head = ""
158163 elif head .startswith ("float" ):
159164 self .type_code = DataTypeCode .FLOAT
160165 head = head [5 :]
@@ -171,9 +176,6 @@ def __init__(self, type_str):
171176 elif head .startswith ("e5m2_float" ):
172177 self .type_code = DataTypeCode .E5M2Float
173178 head = head [10 :]
174- elif head .startswith ("e2m1_float" ):
175- self .type_code = DataTypeCode .E2M1Float
176- head = head [10 :]
177179 elif head .startswith ("custom" ):
178180 # pylint: disable=import-outside-toplevel
179181 import tvm .runtime ._ffi_api
@@ -201,7 +203,12 @@ def __repr__(self):
201203 import tvm .runtime ._ffi_api
202204
203205 type_name = "custom[%s]" % tvm .runtime ._ffi_api ._datatype_get_type_name (self .type_code )
204- x = "%s%d" % (type_name , self .bits )
206+ if self .type_code in [
207+ DataTypeCode .FLOAT4E2M1FN ,
208+ ]:
209+ x = type_name
210+ else :
211+ x = "%s%d" % (type_name , self .bits )
205212 lanes_as_int = ctypes .c_int16 (self .lanes ).value
206213 if lanes_as_int > 1 :
207214 x += "x%d" % self .lanes
@@ -238,7 +245,7 @@ def itemsize(self):
238245 DataType .NUMPY2STR [np .dtype (ml_dtypes .bfloat16 )] = "bfloat16"
239246 DataType .NUMPY2STR [np .dtype (ml_dtypes .float8_e4m3fn )] = "e4m3_float8"
240247 DataType .NUMPY2STR [np .dtype (ml_dtypes .float8_e5m2 )] = "e5m2_float8"
241- DataType .NUMPY2STR [np .dtype (ml_dtypes .float4_e2m1fn )] = "e2m1_float4 "
248+ DataType .NUMPY2STR [np .dtype (ml_dtypes .float4_e2m1fn )] = "float4_e2m1fn "
242249
243250RPC_SESS_MASK = 128
244251
0 commit comments