@@ -71,12 +71,14 @@ def initialize_model_and_tokenizer(model_name_or_path):
     parser.add_argument("--device_map", type=str, default=None, help="device map for model")
     parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
     parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
+    parser.add_argument("--mem_per_param_scale", default=13, type=int, help="memory per param scale factor")
     parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
     parser.add_argument("--seqlen", default=2048, type=int, help="sequence length for autoround.")
     parser.add_argument("--nsamples", default=128, type=int, help="number of samples for autoround.")
     parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
     parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
     parser.add_argument("--save_format", type=str, default="auto_round", help="format to save the quantized model")
+    parser.add_argument("--enable_torch_compile", action="store_true", help="whether to enable torch.compile")
     parser.add_argument("--quant_lm_head", action="store_true", help="whether to quantize lm_head")
     parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
     parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.")
@@ -101,23 +103,29 @@ def initialize_model_and_tokenizer(model_name_or_path):

     model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path)
     device = "hpu" if is_hpex_available() else "cuda"
+    # in case the model is set to cuda:0 by default
+    if args.device_map.isdigit() and device == "cuda":
+        device = f"{device}:{args.device_map}"

     if args.quantize:
-        autoround_dtype_mapping = {
-            "MXFP4": "mx_fp4",
-            "MXFP8": "mx_fp8",
-            "NVFP4": "nv_fp4",
-            "uNVFP4": "fp4_v2",
-            "NVFP4+": "fp4_v2",
-        }
-        args.dtype = autoround_dtype_mapping[args.dtype]
+        if args.dtype in ["uNVFP4", "NVFP4+"]:
+            from auto_round.schemes import QuantizationScheme
+
+            uNVFP4 = QuantizationScheme.from_dict(
+                {
+                    "bits": 4,
+                    "group_size": 16,
+                    "data_type": "fp4_v2",
+                    "act_bits": 4,
+                    "act_data_type": "fp4_v2",
+                    "act_group_size": 16,
+                    "act_sym": True,
+                }
+            )
+            args.dtype = uNVFP4
+
         if args.quant_lm_head:
-            lm_head_config = {
-                "group_size": 32 if "mx" in args.dtype else 16,
-                "data_type": args.dtype,
-                "act_data_type": "fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
-            }
-            layer_config = {"lm_head": lm_head_config}
+            layer_config = {"lm_head": args.dtype}

         autoround = AutoRound(
             model,
@@ -128,10 +136,10 @@ def initialize_model_and_tokenizer(model_name_or_path):
             seqlen=args.seqlen,
             nsamples=args.nsamples,
             low_gpu_mem_usage=True,
-            group_size=32 if "mx" in args.dtype else 16,
-            data_type=args.dtype,
-            act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
+            scheme=args.dtype,
             layer_config=layer_config if args.quant_lm_head else None,
+            enable_torch_compile=args.enable_torch_compile,
+            mem_per_param_scale=args.mem_per_param_scale,
         )

         if args.use_recipe:
@@ -140,20 +148,16 @@ def load_recipe_results(file_path):
                 import json
                 with open(file_path, "r") as f:
                     return json.load(f)
-
+
             layer_config = load_recipe_results(args.recipe_file)
             if args.quant_lm_head:
-                mxfp8_config = {
-                    "bits": 8,
-                    "group_size": 32,
-                    "data_type": "mx_fp8",
-                    "act_data_type": "mx_fp8",
-                }
                 # ensure lm_head is quantized with mxfp8_config
-                layer_config.update({"lm_head": mxfp8_config})
+                layer_config.update({"lm_head": "MXFP8"})
                 print("In recipe mode, lm_head is quantized with MXFP8.")
             autoround.layer_config = layer_config

+        # A placeholder, to pass assertion in AutoRound
+        autoround.formats = "auto_round"
         autoround.quantize()
         model = autoround.model

@@ -192,7 +196,6 @@ def load_recipe_results(file_path):
         else:
             # CUDA evaluation supports all tasks.
             # gsm8k requires add_bos_token=False for better accuracy for llama model.
-            # model = torch.compile(model)
             args.tasks = ["piqa", "hellaswag", "mmlu", "gsm8k"]
             all_accuracy = {}
             test_gsm8k = False
@@ -243,7 +246,7 @@ def load_recipe_results(file_path):
             print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")

     if args.save:
-        if args.dtype == "nv_fp4":
+        if args.dtype == "NVFP4":
             # using llm_compressor format to save nv_fp4 model
             autoround.save_quantized(args.save_path, format=args.save_format)
         else: