@@ -66,10 +66,12 @@ def __init__(
 
     def add_input(self, args):
         if self.inputs is None:
-            self.inputs = [MultiInput([arg]) for arg in args]
+            # self.inputs = [MultiInput([arg]) for arg in args]
+            self.inputs = [GPTQMultiTensor([arg]) for arg in args]
         else:
             self.inputs = [
-                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                # multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
             ]
 
     def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
         self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
 
 
+class GPTQMultiTensor(torch.Tensor):
+    """
+    Tensor subclass that holds a list of tensors (e.g. one per calibration
+    batch) and replays every torch function over each held tensor via
+    __torch_function__, intercepting nn.functional.linear to run GPTQ.
+    """
+    # todo need default shape/dtype
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        if isinstance(input, (list, tuple)):
+            input = input[0]
+        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
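+        # _make_wrapper_subclass allocates no storage: the wrapper only advertises
+        # the shape/dtype of its first tensor, while the data lives in self.values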
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.add_tensors(input)
+        self.debug = True
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(data={self.values})"
+
+    def add_tensors(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.add_tensors(inp)
+        else:
+            assert isinstance(input, torch.Tensor), (
+                f"GPTQMultiTensor can only add_tensors for tensors or lists of tensors but got {type(input)}"
+            )
+            self.values.append(input)
+        return self
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() for val in self.values]
+        return self
+
+    def cpu(self):
+        self.values = [val.cpu() for val in self.values]
+        return self
+
+    def configure_quantization_mode(
+        self,
+        get_qparams_func,
+        quantize_func,
+        dequantize_func,
+        combine_qparams_list_func,
+        make_names_and_values_dict_func,
+        skip_layer_func,
+    ):
+        self.get_qparams_func = get_qparams_func
+        self.quantize_func = quantize_func
+        self.dequantize_func = dequantize_func
+        self.combine_qparams_list_func = combine_qparams_list_func
+        self.skip_layer_func = skip_layer_func
+        self.make_names_and_values_dict_func = make_names_and_values_dict_func
+        return self
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
+        # TODO: handle __setitem__ / aten.index_put_ dispatch (detected via
+        # str(func) under DisableTorchFunctionSubclass); an earlier experiment
+        # translated __setitem__ into index_put, roughly:
+        #     new_arg1 = [None if x == slice(None) else x for x in args[1]]
+        #     return torch.ops.aten.index_put(args[0], new_arg1, args[2])
+
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
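+        # note: tensors_to_cuda is currently unused; its call sites below are commented out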
+
+        def flat_to_grouped(flat):
+            # size of the biggest MultiTensor
+            multi_tensor_size = max(
+                [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
+            )
+            # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+            grouped = list(
+                zip(
+                    *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
+                )
+            )
+            return grouped
+
+        # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+        # where A is a non-tensor and the b's and c's are tensors
+        def grouped_to_flat(grouped):
+            # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
+            flat_tups = list(zip(*grouped))
+            # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+            flattened = [
+                cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
+            ]
+            # check that collapsing each non-tensor tuple to its first element is OK,
+            # i.e. that all elements within the tuple match (all() is True for empty tuples)
+            non_tensors_equal = all(
+                all(x == tup[0] for x in tup)
+                for tup in flat_tups
+                if not isinstance(tup[0], torch.Tensor)
+            )
+            return flattened, non_tensors_equal
+
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs and remove lists and tuples
+        flat_args, spec = tree_flatten((args, kwargs))
+
+        # move single tensors to cuda
+        # flat_args = tensors_to_cuda(flat_args)
+
+        # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+        grouped_args = flat_to_grouped(flat_args)
+
+        do_gptq_linear = (
+            func is nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_gptq
+            # and not (self.skip_layer_func)
+        )
+
+        # run the function for each of the multitensors and return a multitensor
+        if not do_gptq_linear:
+            outputs = []
+            with torch._C.DisableTorchFunctionSubclass():
+                for inp in grouped_args:
+                    # inp = tensors_to_cuda(inp)
+                    cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                    try:
+                        out = func(*cur_args, **cur_kwargs)
+                        outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+                    except Exception as e:
+                        print(e)
+                        print("?B?")
+                        breakpoint()
+                        print("?")
+                try:
+                    # flatten each output
+                    grouped_outputs = [tree_flatten(x)[0] for x in outputs]
+                    out_spec = tree_flatten(outputs[0])[1]
+                    # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+                    flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
+                    assert non_tensors_equal, (
+                        f"ERR: found a function in model: {func} which "
+                        + "caused an error in GPTQMultiTensor; the function dispatch only works for functions"
+                        + " with tensor outputs or whose non-tensor outputs are the same across all inputs"
+                    )
+                    return tree_unflatten(flat_outputs, out_spec)
+                except Exception as e:
+                    print(e)
+                    print("?C?")
+                    breakpoint()
+                    print("?")
+
+        # GPTQ path: func is an un-skipped nn.functional.linear
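+        # build a running Hessian estimate H = (2/N) * sum_i x_i @ x_i.T over the
+        # N input rows seen so far: each step rescales the old average by
+        # total_batches / (total_batches + n) before folding in n new rows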
+        total_batches = 0
+        H = 0
+        for inp in grouped_args:
+            # inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H *= total_batches / (total_batches + n)
+            total_batches += n
+            x = (
+                (2 / total_batches) ** (1 / 2) *
+                x.reshape(-1, shape[-1]).t().float()
+            )
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
+        DQ = W + .01
+        # Q, DQ, qparams = args[0].faster_quant(H, W.detach())
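+        # NOTE: DQ = W + .01 is a stand-in perturbation while faster_quant is
+        # commented out, so the surrounding dispatch machinery can be exercised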
+
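+        # re-dispatch the same linear with the dequantized weight; skip_gptq=True
+        # routes this call through the plain per-batch path above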
+        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq=True)
+        # if args[0].debug:
+        return new_out
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
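+        # not implemented yet; drop into the debugger if anything reaches dispatch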
+        breakpoint()
+        pass
+
+
 class GenericGPTQRunner(fx.Interpreter):
     """
     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
@@ -161,7 +356,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True
 
     def configure_quantization_mode(
         self,
@@ -312,6 +507,16 @@ def SQNR(x, y):
                 print(
                     "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
                 )  # matches
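+                # sanity check: re-quantizing DQ should reproduce Q exactly, and a
+                # second dequantize should be lossless (SQNR = inf)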
+                qparams_after = self.get_qparams_func(DQ)
+                Q_after = self.quantize_func(DQ, qparams_after)
+                print(
+                    "abs difference of Q-quant(DQ)", (Q - Q_after).abs().sum()
+                )
+                DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
+                print(
+                    "SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
+                )
+                breakpoint()
 
                 print(
                     "SQNR for weight (can be low)", SQNR(W, DQ.cuda())