Commit 7e475fe

refactor PT 3x structure and refine RTN
Signed-off-by: xin3he <[email protected]>
1 parent 31743fe commit 7e475fe

19 files changed: +501 −442 lines


neural_compressor/torch/__init__.py

Lines changed: 0 additions & 18 deletions
@@ -11,21 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from neural_compressor.torch.utils.utility import register_algo
-from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry
-
-from neural_compressor.torch.quantization import (
-    quantize,
-    RTNConfig,
-    get_default_rtn_config,
-    GPTQConfig,
-    get_default_gptq_config,
-    StaticQuantConfig,
-    get_default_static_config,
-    SmoothQuantConfig,
-    get_default_sq_config,
-)
-
-from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
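
With these re-exports removed from the package root, the same entry points would presumably be imported from their defining sub-modules. A minimal sketch of the equivalent direct imports, assuming the sub-module paths named in the deleted lines remain valid after the refactor:

# Hypothetical usage after the refactor; paths taken from the deleted re-export lines above.
from neural_compressor.torch.quantization import quantize, RTNConfig, get_default_rtn_config
from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set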

neural_compressor/torch/algorithms/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -11,7 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
-from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry

neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 9 additions & 9 deletions
@@ -757,13 +757,14 @@ def find_params(self, x, weight=False):
         if self.wdtype != "int":
             from .utility import quant_tensor
 
-            tmp = x.clone()  # make sure x is not replaced
+            tmp = x.clone()  # tmp will be replaced after quant_tensor
+
             _, scale, zero = quant_tensor(
                 tmp,
-                self.wbits,
-                self.group_size,
-                scheme=self.scheme,
                 dtype=self.wdtype,
+                bits=self.wbits,
+                group_size=self.group_size,
+                scheme=self.scheme,
                 quantile=1.0,
                 return_int=True,
                 full_range=False,
@@ -854,10 +855,10 @@ def find_params(self, x, weight=False):
             self.scale = self.scale.reshape(1, -1)
             quant_tensor(
                 self.scale,
-                self.double_quant_bits,
-                self.double_quant_group_size,
-                scheme=self.double_quant_scheme,
                 dtype=self.double_quant_dtype,
+                bits=self.double_quant_bits,
+                group_size=self.double_quant_group_size,
+                scheme=self.double_quant_scheme,
                 quantile=1.0,
                 return_int=False,
                 full_range=False,
@@ -879,8 +880,7 @@ def quantize(self, x, scale, zero, maxq):
         if self.wdtype != "int":
             from .utility import quantize_4bit
 
-            tmp = x.clone()
-
+            tmp = x.clone()  # tmp will be replaced after quant_tensor
             return quantize_4bit(tmp, dtype=self.wdtype, scale=scale)
         else:
             if maxq < 0:
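
Both hunks switch quant_tensor over to keyword arguments, with dtype ahead of bits, group_size, and scheme. A minimal sketch of the resulting call shape; the module path is an assumption inferred from this file's relative `.utility` import, the three-value return follows the `_, scale, zero = quant_tensor(...)` hunk above, and the tensor values are illustrative:

import torch

# Assumed path, based on this file's location and its `from .utility import quant_tensor`.
from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor

w = torch.randn(64, 64)
tmp = w.clone()  # per the updated comment, tmp is replaced in place by quant_tensor
q, scale, zp = quant_tensor(
    tmp,
    dtype="int",
    bits=4,
    group_size=32,
    scheme="sym",
    quantile=1.0,
    return_int=True,
    full_range=False,
)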

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 63 additions & 88 deletions
@@ -21,34 +21,33 @@
 
 import torch
 
-from neural_compressor.torch.utils import logger
-from neural_compressor.torch.utils.utility import set_module
+from neural_compressor.torch.utils import logger, set_module
 
 from .utility import quant_tensor, search_clip
 
 
 @torch.no_grad()
 def rtn_quantize(
     model,
-    num_bits=4,
+    dtype="int",
+    bits=4,
+    scheme="sym",
     group_size=32,
-    scheme="asym",
+    group_dim=1,
     quantile=1.0,
     weight_config={},
-    return_int=False,
-    dtype="int",
-    enable_full_range=False,
-    enable_mse_search=False,
-    group_dim=1,
+    export_compressed_model=False,
+    use_full_range=False,
+    use_mse_search=False,
     **kwargs,
 ):
-    """Quant the model with round to nearest method.
+    """Quant the model with round to nearest method and inplace is True.
 
     Args:
         model: torch module
-        num_bits: num bits. Defaults to 4.
+        bits: num bits. Defaults to 4.
         group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
-        scheme (str, optional): sym or asym. Defaults to "asym".
+        scheme (str, optional): sym or asym. Defaults to "sym".
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         dtype (str, optional): select from int, nf4, fp4. Defaults to int.
         weight_config (dict, optional): specific layer wise configurations. Defaults to {}.
@@ -60,88 +59,98 @@ def rtn_quantize(
                     'bits': 4,
                     'group_size': 32,
                     'scheme': 'sym'
-                    'gptq_perm': [1, 1, ...] # for gptq perm
                 }
             }
-        return_int (bool, optional): Choose return fp32 or int32 model.
+        export_compressed_model (bool, optional): Choose return fp32 or int32 model.
                 Defaults to False.
-        enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+        use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
                 Defaults to False.
-        enable_mse_search (bool, optional): Whether search clip range.
+        use_mse_search (bool, optional): Whether search clip range.
                 Defaults to True.
         group_dim (int, optional): 0 means splitting output channel,
             1 means splitting input channel. Defaults to 1.
 
     Returns:
         model: fake quantized torch module
     """
+    device = "cpu"
     assert isinstance(model, torch.nn.Module), "only support torch module"
     supported_layers = ["Linear"]
-    double_quant_dtype = kwargs.get("double_quant_dtype", "fp32")
+    # initialize global configuration
    double_quant_config = {
-        "double_quant": False if double_quant_dtype == "fp32" else True,
-        "double_quant_dtype": double_quant_dtype,
-        "double_quant_num_bits": kwargs.get("double_quant_num_bits", 8),
+        "double_quant": kwargs.get("use_double_quant", False),
+        "double_quant_dtype": kwargs.get("double_quant_dtype", "int"),
+        "double_quant_bits": kwargs.get("double_quant_bits", 8),
         "double_quant_scheme": kwargs.get("double_quant_scheme", "sym"),
         "double_quant_group_size": kwargs.get("double_quant_group_size", 256),
     }
-    if return_int:
-        compression_dtype = kwargs.get("compression_dtype", torch.int32)
-        compression_dim = kwargs.get("compression_dim", 1)
-        scale_dtype = kwargs.get("scale_dtype", torch.float32)
-        device = kwargs.get("device", "cpu")
+    if export_compressed_model:
+        use_optimum_format = kwargs.get("use_optimum_format", True)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
         if name in weight_config:  # pragma: no cover
+            # initialize op configuration
             dtype = weight_config[name].get("dtype", "int")
-            num_bits = weight_config[name]["bits"]
+            bits = weight_config[name].get("bits", 4)
             group_size = weight_config[name]["group_size"]
             scheme = weight_config[name]["scheme"]
             quantile = weight_config[name].get("quantile", 1.0)
+            group_dim = weight_config[name]["group_dim"]
+            use_full_range = weight_config[name]["use_full_range"]
+            use_mse_search = weight_config[name]["use_mse_search"]
+            use_layer_wise = weight_config[name]["use_layer_wise"]
+            export_compressed_model = weight_config[name]["export_compressed_model"]
+            if export_compressed_model:
+                use_optimum_format = kwargs.get("use_optimum_format", True)
+            # double quant config
+            double_quant_config = {
+                "double_quant": weight_config[name]["use_double_quant"],
+                "double_quant_dtype": weight_config[name]["double_quant_dtype"],
+                "double_quant_bits": weight_config[name]["double_quant_bits"],
+                "double_quant_scheme": weight_config[name]["double_quant_scheme"],
+                "double_quant_group_size": weight_config[name]["double_quant_group_size"],
+            }
         log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
+            f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}"
         )
         if dtype != "int":
             log_msg += f", dtype={dtype}"
         elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
+            log_msg += f", use_full_range={use_full_range}"
         if dtype == "fp32":
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, dtype, enable_full_range)
-        if return_int:
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
+        if use_mse_search:
+            quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
+        if export_compressed_model:
             int_weight, scale, zp = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
                 return_int=True,
-                full_range=enable_full_range,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             from neural_compressor.torch.quantization.layers import WeightOnlyLinear
 
             new_module = WeightOnlyLinear(
                 m.in_features,
                 m.out_features,
-                num_bits,
-                group_size,
+                bits=bits,
+                group_size=group_size,
                 dtype=dtype,
                 zp=zp is not None,
                 bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
+                use_optimum_format=use_optimum_format,
                 device=device,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
@@ -150,50 +159,16 @@ def rtn_quantize(
             else:
                 set_module(model, name, new_module)
         else:
-            q_weight = quant_tensor(
+            weight = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
-                full_range=enable_full_range,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
+            weight = weight.t_().contiguous() if group_dim == 0 else weight
+            m.weight.data.copy_(weight)
     return model
-
-
-from neural_compressor.torch.quantization.config import RTNConfig
-
-
-def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
-    # TODO (Yi) remove it
-    enable_full_range = quant_config.enable_full_range
-    enable_mse_search = quant_config.enable_mse_search
-    group_dim = quant_config.group_dim
-    dtype = quant_config.weight_dtype
-    num_bits = quant_config.weight_bits
-    scheme = "sym" if quant_config.weight_sym else "asym"
-    group_size = quant_config.weight_group_size
-    return_int = quant_config.return_int
-    double_quant_dtype = quant_config.double_quant_dtype
-    double_quant_num_bits = quant_config.double_quant_bits
-    double_quant_scheme = "sym" if quant_config.double_quant_sym else "asym"
-    double_quant_group_size = quant_config.double_quant_group_size
-    return rtn_quantize(
-        module,
-        num_bits,
-        group_size,
-        scheme,
-        return_int=return_int,
-        dtype=dtype,
-        enable_full_range=enable_full_range,
-        enable_mse_search=enable_mse_search,
-        group_dim=group_dim,
-        double_quant_dtype=double_quant_dtype,
-        double_quant_scheme=double_quant_scheme,
-        double_quant_num_bits=double_quant_num_bits,
-        double_quant_group_size=double_quant_group_size,
-    )
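
For reference, a minimal sketch of calling the refactored entry point with the renamed arguments; the toy model is illustrative, and the keyword names and defaults follow the new signature above:

import torch

# Module path follows this file's location; the Sequential model is only an example.
from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize

model = torch.nn.Sequential(torch.nn.Linear(128, 256))
model = rtn_quantize(
    model,
    dtype="int",
    bits=4,                         # was num_bits
    scheme="sym",                   # new default, was "asym"
    group_size=32,
    use_full_range=False,           # was enable_full_range
    use_mse_search=False,           # was enable_mse_search
    export_compressed_model=False,  # was return_int; False keeps the fake-quantized fp32 module
)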
