Skip to content

Commit 2a1d261

Browse files
authored
Smooth_quant update with Iterator (#259)
1 parent 2995e1c commit 2a1d261

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

python/mlc_llm/quantization/smooth_quantization.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The SmoothQuant config"""
22

33
from dataclasses import dataclass
4-
from typing import List, Literal, Union, Dict, Any
4+
from typing import List, Literal, Union, Dict, Any, Iterator, Tuple
55
from collections import OrderedDict
66
import numpy as np
77
import os
@@ -75,7 +75,7 @@ def load_file(path):
7575
return loaded_dict
7676

7777

78-
def shard_smoothquant_params(tensor_parallel_shards, args):
78+
def shard_smoothquant_params(tensor_parallel_shards, args) -> Iterator[Tuple[str, NDArray]]:
7979
model_config = args.model.config.from_file(args.config)
8080
model_config.tensor_parallel_shards = tensor_parallel_shards
8181
model = args.model.model(model_config)
@@ -89,7 +89,6 @@ def shard_smoothquant_params(tensor_parallel_shards, args):
8989
smoothing_factors_dict, _ = tvmjs.load_ndarray_cache(f"{pth}/smooth/", tvm.cpu())
9090
scales_dict, _ = tvmjs.load_ndarray_cache(f"{pth}/quantize/", tvm.cpu())
9191

92-
out = OrderedDict()
9392
smooth_0_quants = ["smq_q8i8f16_0", "smq_e4m3_float8_0", "smq_e5m2_float8_0"]
9493
for name, param in model.state_dict().items():
9594
smooth_factor_names = param_to_smooth_factor["prefill"].pop(name, None)
@@ -115,18 +114,17 @@ def shard_smoothquant_params(tensor_parallel_shards, args):
115114
a_zps = _split_array(scales_dict[a_zp], tensor_parallel_shards)
116115
w_zps = _duplicate_array(scales_dict[w_zp], tensor_parallel_shards)
117116
for shard_idx in range(tensor_parallel_shards):
118-
out[_sharded_param_name(a_factor, shard_idx)] = a_factors[shard_idx]
119-
out[_sharded_param_name(w_factor, shard_idx)] = w_factors[shard_idx]
117+
yield _sharded_param_name(a_factor, shard_idx), a_factors[shard_idx]
118+
yield _sharded_param_name(w_factor, shard_idx), w_factors[shard_idx]
120119
if not args.quantization.name in smooth_0_quants:
121-
out[_sharded_param_name(w_scale, shard_idx)] = w_scales[shard_idx]
122-
out[_sharded_param_name(w_zp, shard_idx)] = w_zps[shard_idx]
120+
yield _sharded_param_name(w_scale, shard_idx), w_scales[shard_idx]
121+
yield _sharded_param_name(w_zp, shard_idx), w_zps[shard_idx]
123122
else:
124-
out[a_factor] = smoothing_factors_dict[a_factor]
125-
out[w_factor] = smoothing_factors_dict[w_factor]
123+
yield a_factor, smoothing_factors_dict[a_factor]
124+
yield w_factor, smoothing_factors_dict[w_factor]
126125
if not args.quantization.name in smooth_0_quants:
127-
out[w_scale] = scales_dict[w_scale]
128-
out[w_zp] = scales_dict[w_zp]
129-
return out
126+
yield w_scale, scales_dict[w_scale]
127+
yield w_zp, scales_dict[w_zp]
130128

131129

132130
def _create_smoothquant_func(

0 commit comments

Comments
 (0)