11"""The SmoothQuant config"""
22
33from dataclasses import dataclass
4- from typing import List , Literal , Union , Dict , Any
4+ from typing import List , Literal , Union , Dict , Any , Iterator , Tuple
55from collections import OrderedDict
66import numpy as np
77import os
@@ -75,7 +75,7 @@ def load_file(path):
7575 return loaded_dict
7676
7777
78- def shard_smoothquant_params (tensor_parallel_shards , args ):
78+ def shard_smoothquant_params (tensor_parallel_shards , args ) -> Iterator [ Tuple [ str , NDArray ]] :
7979 model_config = args .model .config .from_file (args .config )
8080 model_config .tensor_parallel_shards = tensor_parallel_shards
8181 model = args .model .model (model_config )
@@ -89,7 +89,6 @@ def shard_smoothquant_params(tensor_parallel_shards, args):
8989 smoothing_factors_dict , _ = tvmjs .load_ndarray_cache (f"{ pth } /smooth/" , tvm .cpu ())
9090 scales_dict , _ = tvmjs .load_ndarray_cache (f"{ pth } /quantize/" , tvm .cpu ())
9191
92- out = OrderedDict ()
9392 smooth_0_quants = ["smq_q8i8f16_0" , "smq_e4m3_float8_0" , "smq_e5m2_float8_0" ]
9493 for name , param in model .state_dict ().items ():
9594 smooth_factor_names = param_to_smooth_factor ["prefill" ].pop (name , None )
@@ -115,18 +114,17 @@ def shard_smoothquant_params(tensor_parallel_shards, args):
115114 a_zps = _split_array (scales_dict [a_zp ], tensor_parallel_shards )
116115 w_zps = _duplicate_array (scales_dict [w_zp ], tensor_parallel_shards )
117116 for shard_idx in range (tensor_parallel_shards ):
118- out [ _sharded_param_name (a_factor , shard_idx )] = a_factors [shard_idx ]
119- out [ _sharded_param_name (w_factor , shard_idx )] = w_factors [shard_idx ]
117+ yield _sharded_param_name (a_factor , shard_idx ), a_factors [shard_idx ]
118+ yield _sharded_param_name (w_factor , shard_idx ), w_factors [shard_idx ]
120119 if not args .quantization .name in smooth_0_quants :
121- out [ _sharded_param_name (w_scale , shard_idx )] = w_scales [shard_idx ]
122- out [ _sharded_param_name (w_zp , shard_idx )] = w_zps [shard_idx ]
120+ yield _sharded_param_name (w_scale , shard_idx ), w_scales [shard_idx ]
121+ yield _sharded_param_name (w_zp , shard_idx ), w_zps [shard_idx ]
123122 else :
124- out [ a_factor ] = smoothing_factors_dict [a_factor ]
125- out [ w_factor ] = smoothing_factors_dict [w_factor ]
123+ yield a_factor , smoothing_factors_dict [a_factor ]
124+ yield w_factor , smoothing_factors_dict [w_factor ]
126125 if not args .quantization .name in smooth_0_quants :
127- out [w_scale ] = scales_dict [w_scale ]
128- out [w_zp ] = scales_dict [w_zp ]
129- return out
126+ yield w_scale , scales_dict [w_scale ]
127+ yield w_zp , scales_dict [w_zp ]
130128
131129
132130def _create_smoothquant_func (
0 commit comments