@@ -76,6 +76,46 @@ def create_quantize_func(
     return bb.finalize()
 
 
+def create_dequantize_func(
+    packed_weight_shape,
+    scale_shape,
+    dequantized_shape,
+    model_dtype,
+    quantize_dtype,
+    storage_dtype,
+    group_size,
+    num_elem_per_storage,
+    axis,
+) -> IRModule:
+    if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+        dequantize_func = dequantize_fp8x4_e4m3
+    else:
+        raise NotImplementedError()
+
+    bb = relax.BlockBuilder()  # pylint: disable=invalid-name
+    packed_weight_var = relax.Var(
+        "weight", relax.TensorStructInfo(packed_weight_shape, storage_dtype)
+    )
+    scale_var = relax.Var("scale", relax.TensorStructInfo(scale_shape, model_dtype))
+    compute_dequantize = dequantize_func(
+        packed_weight_shape,
+        scale_shape,
+        dequantized_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_elem_per_storage,
+        axis,
+    )
+    with bb.function(name="main", params=[packed_weight_var, scale_var]):
+        with bb.dataflow():
+            lv = compute_dequantize(bb, (packed_weight_var, scale_var))
+            gv = bb.emit_output(lv)
+        bb.emit_func_output(gv)
+    return bb.finalize()
+
+
 def quantize_fp8x4_e4m3(  # pylint: disable=too-many-locals
     weight_shape: List[tir.PrimExpr],
     model_dtype,
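
For reference, a minimal sketch of how `create_dequantize_func` might be exercised on its own. The shapes and dtypes below are illustrative, not taken from this commit; they follow the packing convention used here (four e4m3 values per `uint32` storage element along the group axis):

```python
# Hypothetical shapes: 4 fp8 values per uint32, one scale per group of 32.
weight_shape = (64, 128)        # dequantized output shape
packed_shape = (64, 128 // 4)   # uint32 storage
scale_shape = (64, 128 // 32)   # per-group scales

mod = create_dequantize_func(
    packed_shape,
    scale_shape,
    weight_shape,
    model_dtype="float16",
    quantize_dtype="e4m3_float8",
    storage_dtype="uint32",
    group_size=32,
    num_elem_per_storage=4,
    axis=1,
)
mod.show()  # Relax "main" plus the emitted TIR dequantize kernel
```
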
@@ -135,9 +175,6 @@ def compute_quantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
             quantize_dtype,
         )
         # quant.show()
-        # import ipdb
-
-        # ipdb.set_trace()
 
         global_var = bb.add_func(quant, "quantized_weight")
         lv_quantized_weight = bb.emit(
@@ -161,6 +198,41 @@ def compute_transpose(quantized_weight: te.Tensor, scale: te.Tensor):
     return compute_scale, compute_quantize_weight, compute_transpose
 
 
+def dequantize_fp8x4_e4m3(  # pylint: disable=too-many-locals
+    packed_weight_shape: List[tir.PrimExpr],
+    scale_shape,
+    dequant_shape,
+    model_dtype,
+    quantize_dtype,
+    storage_dtype,
+    group_size,
+    num_elem_per_storage,
+    axis: int = -1,
+):
+    """Group dequantization for weight tensor, defined in tensor expression."""
+    axis = axis if axis >= 0 else len(dequant_shape) + axis
+
+    def compute_dequantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
+        dequant = dequant_fp8x4_e4m3_sm90(
+            packed_weight_shape,
+            scale_shape,
+            dequant_shape,
+            group_size,
+            axis,
+            model_dtype,
+            storage_dtype,
+            quantize_dtype,
+        )
+
+        global_var = bb.add_func(dequant, "dequantize_weight")
+        lv_dequantized_weight = bb.emit(
+            relax.call_tir(global_var, args, relax.TensorStructInfo(dequant_shape, model_dtype))
+        )
+        return lv_dequantized_weight
+
+    return compute_dequantize_weight
+
+
 def quant_and_pack_fp8x4_e4m3_sm90(
     weight_shape,
     packed_shape,
@@ -175,6 +247,7 @@ def quant_and_pack_fp8x4_e4m3_sm90(
     vec_quantized_dtype = f"{quantized_dtype}x{vector_length}"
     vec_model_dtype = f"{model_dtype}x{vector_length}"
     num_elem_per_storage = vector_length
+    # TODO(csullivan): assert storage dtype bytes / quantize dtype bytes == vector length
     assert (
         group_size % vector_length == 0
     ), f"Number of elements in a group must be divisible by fp8 vector length {vector_length}"
@@ -202,14 +275,11 @@ def quant_pack(
                     storage_dtype,
                     T.Cast(
                         vec_quantized_dtype,
-                        # Note: Using the colon here is a sugared way of writing T.ramp(v_i1, 1, vector_length)
-                        # ie a vector load of A
-                        A[v_i0, v_i1 : v_i1 + vector_length]
+                        A[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)]
                         / scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
                     ),
                 )
 
-    quant_pack.show()
     return quant_pack
 
 
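The kernel above scales a `float16x4` vector, casts it to `e4m3_float8x4`, and reinterprets the four resulting bytes as one `uint32`. A byte-level sanity check of that pack step in plain NumPy (this mirrors the TIR, it is not called by it; the optional `ml_dtypes` package is assumed for the e4m3 encoding):

```python
import numpy as np
import ml_dtypes  # assumed available; provides float8_e4m3fn

w = np.array([1.5, -2.0, 3.25, 0.5], dtype=np.float16)  # one fp16x4 vector
scale = np.float16(0.0078125)                           # shared group scale

# Cast the scaled vector to e4m3 (1 byte per value), then view the four
# bytes as a single uint32 -- the same reinterpret the kernel performs.
q = (w / scale).astype(ml_dtypes.float8_e4m3fn)
packed = q.view(np.uint32)
assert packed.shape == (1,)

# Unpacking inverts it: view the uint32 as four fp8 bytes, upcast, rescale.
roundtrip = packed.view(ml_dtypes.float8_e4m3fn).astype(np.float16) * scale
```
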
@@ -251,7 +321,6 @@ def dequant(
                     scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)], vector_length
                 )
 
-    dequant.show()
     return dequant
 
 
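In both kernels the scale index is `v_i1 * vector_length // group_size`: storage element `i1` covers weight columns `[4 * i1, 4 * i1 + 4)`, and because `group_size` is asserted to be a multiple of the vector length, all four columns always land in the same group, so one scale serves the whole vector. A quick check of that mapping:

```python
vector_length, group_size = 4, 32
for i1 in (0, 7, 8, 9):
    cols = range(i1 * vector_length, (i1 + 1) * vector_length)
    # every column in the vector belongs to the same quantization group
    assert {c // group_size for c in cols} == {i1 * vector_length // group_size}
# e.g. storage element 9 holds columns 36..39, all in group 1
```
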
@@ -543,100 +612,21 @@ def add(
     tvm.testing.assert_allclose(c.numpy(), c_expected, atol=1e-5, rtol=1e-5)
 
 
-@tvm.testing.requires_cuda_compute_version(8)
-def test_weight_scale():
-    weight_shape = [32000, 4096]
-    group_size = 32
-    axis = 1
-    scale_shape = [d // group_size if axis == i else d for i, d in enumerate(weight_shape)]
-    model_dtype = "float16"
-    storage_dtype = "uint32"
-    quantized_dtype = "e4m3_float8"
-
-    # q_weight = fp8(weight_f16 / scale_f16)
-    # q_weight = fp8x4(weight_f16x4 / scale_f16x4)
-    vector_length = 4
-    vec_quantized_dtype = "e4m3_float8x4"
-    vec_model_dtype = "float16x4"
-    num_el_per_storage = 4
-
-    @T.prim_func
-    def vectorized(
-        A: T.Buffer(weight_shape, model_dtype),
-        scale: T.Buffer(scale_shape, model_dtype),
-        compute: T.Buffer(
-            (T.int64(weight_shape[0]), T.int64(weight_shape[1] // num_el_per_storage)),
-            storage_dtype,
-        ),
-    ):
-        T.func_attr({"tir.noalias": T.bool(True)})
-        # with T.block("root"):
-        # test = T.alloc_buffer(1, dtype=vec_model_dtype, scope="local")
-        for i0, i1 in T.grid(T.int64(weight_shape[0]), T.int64(weight_shape[1])):
-            with T.block("compute"):
-                v_i0 = T.axis.spatial(T.int64(weight_shape[0]), i0)
-                v_i1 = T.axis.spatial(T.int64(weight_shape[1] // vector_length), i1)
-                T.reads(
-                    A[v_i0, v_i1 : v_i1 + vector_length], scale[v_i0, v_i1 // T.int64(group_size)]
-                )
-                T.writes(compute[v_i0, v_i1 * vector_length])
-                compute[v_i0, v_i1 * vector_length] = T.reinterpret(
-                    storage_dtype,
-                    T.Cast(
-                        vec_quantized_dtype,
-                        # Note: Using the colon here is a sugared way of writing T.ramp(v_i1, 1, vector_length)
-                        # ie a vector load of A
-                        A[v_i0, v_i1 : v_i1 + vector_length]
-                        / scale[v_i0, v_i1 // T.int64(group_size)],
-                    ),
-                )
-
-    sch = tvm.tir.Schedule(vectorized)
-    block = sch.get_block("compute")
-    loops = sch.get_loops(block)
-    txo, txi = sch.split(loops[0], factors=[None, 256])
-    sch.bind(loops[1], "blockIdx.x")
-    sch.bind(txi, "threadIdx.x")
-    sch.mod.show()
-
-    # sch = tvm.tir.Schedule(main)
-    # block = sch.get_block("compute")
-    # loops = sch.get_loops(block)
-    # bx, tx, lanes = sch.split(loops[-1], factors=[None, 32, 4])
-    # w_l = sch.cache_read(block, 0, storage_scope="local")
-    # # s_l = sch.cache_read(block, 1, storage_scope="local")
-    # sch.compute_at(block=w_l, loop=tx)
-    # # sch.compute_at(block=s_l, loop=tx)
-    # sch.bind(bx, "blockIdx.x")
-    # sch.bind(tx, "threadIdx.x")
-    # # sch.vectorize(lanes)
-    # sch.mod.show()
-
-    import ipdb
-
-    ipdb.set_trace()
-    target = "cuda"
-    f = tvm.build(sch.mod, target=target)
-    print(f.imported_modules[0].get_source())
-
-
-weight_shape = tvm.testing.parameter([32000, 4096], [4096, 14336])
+weight_shape = tvm.testing.parameter((32000, 4096), (4096, 14336))
 
 
 @tvm.testing.requires_cuda_compute_version(8)
-def test_fp8_e4_quant_weight(weight_shape):
+def test_fp8e4x4_quant_dequant_weight(weight_shape):
     group_size = 32
     axis = 1
     scale_shape = [d // group_size if axis == i else d for i, d in enumerate(weight_shape)]
     model_dtype = "float16"
     storage_dtype = "uint32"
     quantize_dtype = "e4m3_float8"
     num_el_per_storage = 4
+    max_int_value = 448
 
-    # TODO(csullivan): check this
-    max_int_value = 448 if "e4m3" in quantize_dtype else 57344
-
-    mod = create_quantize_func(
+    quant_mod = create_quantize_func(
         weight_shape,
         model_dtype,
         quantize_dtype,
@@ -647,33 +637,60 @@ def test_fp8_e4_quant_weight(weight_shape):
         axis,
         output_transpose=False,
     )
+    # quant_mod.show()
 
     target_str = "cuda"
     target = tvm.target.Target(target_str)
     dev = tvm.device(target_str, 0)
     with target:
-        mod = dl.ApplyDefaultSchedule(
+        quant_mod = dl.ApplyDefaultSchedule(
             dl.gpu.Reduction(),
             dl.gpu.GeneralReduction(),
             dl.gpu.Fallback(),
-        )(mod)
-
-    mod.show()
-
-    f = tvm.build(mod["compute_scale"], target=target)
-    cuda_src = f.imported_modules[0].get_source()
-    print(cuda_src)
+        )(quant_mod)
+    ex = relax.build(quant_mod, target=target)
+    vm = relax.VirtualMachine(ex, dev)
 
-    ex = relax.build(mod, target=target)
-
-    vm = relax.VirtualMachine(ex, dev)  # pylint: disable=invalid-name
+    def print_cuda(target, mod, name=None):
+        if name:
+            mod = mod[name]
+        f = tvm.build(mod, target=target)
+        cuda_src = f.imported_modules[0].get_source()
+        print(cuda_src)
 
     weight_np = np.random.uniform(-100, 100, weight_shape).astype(model_dtype)
     weight = tvm.nd.array(weight_np, device=dev)
     quant_weight, scales = vm["main"](weight)
     quant_weight_np, scales_np = quant_weight.numpy(), scales.numpy()
 
-    print(quant_weight_np, scales_np)
+    dequant_mod = create_dequantize_func(
+        quant_weight.shape,
+        scales.shape,
+        weight.shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_el_per_storage,
+        axis,
+    )
+    # dequant_mod.show()
+
+    with target:
+        dequant_mod = dl.ApplyDefaultSchedule(
+            dl.gpu.Reduction(),
+            dl.gpu.GeneralReduction(),
+            dl.gpu.Fallback(),
+        )(dequant_mod)
+    # dequant_mod.show()
+
+    print_cuda(target, dequant_mod, name="dequantize_weight")
+
+    ex = relax.build(dequant_mod, target=target)
+    vm = relax.VirtualMachine(ex, dev)
+    dequant_weight = vm["main"](quant_weight, scales)
+    dequant_weight_np = dequant_weight.numpy()
+    tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2)
 
 
 if __name__ == "__main__":
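
The loose `atol=10, rtol=5e-2` reflect e4m3 precision: the gap between adjacent e4m3 values near the format's maximum (448) is 32, roughly 7% of full scale, so weights drawn from [-100, 100] can shift by a few units on a round trip. A NumPy estimate of that error under the same per-group scaling scheme (again assuming the optional `ml_dtypes` package for e4m3 emulation; shapes are illustrative):

```python
import numpy as np
import ml_dtypes  # assumed available for e4m3 emulation

rng = np.random.default_rng(0)
w = rng.uniform(-100, 100, (64, 128)).astype(np.float16)

# Per-group scale: group max / e4m3 max-normal (448), groups of 32 on axis 1.
scale = np.abs(w).reshape(64, -1, 32).max(axis=-1) / np.float16(448)
s = np.repeat(scale, 32, axis=1)

# Quantize, dequantize, and measure the worst-case round-trip error.
rt = (w / s).astype(ml_dtypes.float8_e4m3fn).astype(np.float16) * s
print(np.abs(w - rt).max())  # typically a few units, consistent with atol=10
```
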