Merged
18 changes: 0 additions & 18 deletions neural_compressor/torch/__init__.py
@@ -11,21 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.torch.utils.utility import register_algo
from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry

from neural_compressor.torch.quantization import (
quantize,
RTNConfig,
get_default_rtn_config,
GPTQConfig,
get_default_gptq_config,
StaticQuantConfig,
get_default_static_config,
SmoothQuantConfig,
get_default_sq_config,
)

from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
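Note: this hunk removes every re-export from the top-level neural_compressor.torch package, so after this PR callers presumably import the quantization API from the subpackage directly. A minimal sketch of the assumed post-refactor usage, reusing only names from the removed import list; the toy model and the exact quantize signature are assumptions, not taken from this diff:

import torch

from neural_compressor.torch.quantization import (
    RTNConfig,
    get_default_rtn_config,
    quantize,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # toy model for illustration
quant_config = get_default_rtn_config()  # default weight-only RTN settings
q_model = quantize(model, quant_config)  # assumed signature: quantize(model, quant_config)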
4 changes: 0 additions & 4 deletions neural_compressor/torch/algorithms/__init__.py
@@ -11,7 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry
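Note: the rtn_quantize_entry and gptq_quantize_entry imports dropped here are the functions registered through the register_algo decorator that this PR also stops re-exporting. A minimal sketch of how such a name-keyed registry typically works; the algos_mapping dict and the "rtn" key are illustrative, not the library's actual internals:

from typing import Callable, Dict

algos_mapping: Dict[str, Callable] = {}  # hypothetical registry: algo name -> entry function

def register_algo(name: str):
    # Decorator that records a quantization entry function under `name`.
    def decorator(algo_entry: Callable) -> Callable:
        algos_mapping[name] = algo_entry
        return algo_entry
    return decorator

@register_algo(name="rtn")  # illustrative registration of an entry point
def rtn_quantize_entry(model, quant_config, *args, **kwargs):
    ...  # would dispatch to rtn_quantize with options taken from quant_config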
18 changes: 9 additions & 9 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -757,13 +757,14 @@ def find_params(self, x, weight=False):
if self.wdtype != "int":
from .utility import quant_tensor

tmp = x.clone() # make sure x is not replaced
tmp = x.clone() # tmp is modified in place by quant_tensor

_, scale, zero = quant_tensor(
tmp,
self.wbits,
self.group_size,
scheme=self.scheme,
dtype=self.wdtype,
bits=self.wbits,
group_size=self.group_size,
scheme=self.scheme,
quantile=1.0,
return_int=True,
full_range=False,
@@ -854,10 +855,10 @@ def find_params(self, x, weight=False):
self.scale = self.scale.reshape(1, -1)
quant_tensor(
self.scale,
self.double_quant_bits,
self.double_quant_group_size,
scheme=self.double_quant_scheme,
dtype=self.double_quant_dtype,
bits=self.double_quant_bits,
group_size=self.double_quant_group_size,
scheme=self.double_quant_scheme,
quantile=1.0,
return_int=False,
full_range=False,
@@ -879,8 +880,7 @@ def quantize(self, x, scale, zero, maxq):
if self.wdtype != "int":
from .utility import quantize_4bit

tmp = x.clone()

tmp = x.clone() # tmp is modified in place by quantize_4bit
return quantize_4bit(tmp, dtype=self.wdtype, scale=scale)
else:
if maxq < 0:
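Note: both gptq.py hunks above migrate quant_tensor from positional to keyword arguments, with dtype promoted ahead of bits. A hedged sketch of the call convention this implies; the signature is inferred from these call sites, not copied from utility.py:

import torch

from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor

weight = torch.randn(64, 64)  # toy weight tensor for illustration
int_weight, scale, zp = quant_tensor(
    weight.clone(),  # cloned because quant_tensor modifies its input in place
    dtype="int",  # "int", "nf4", or "fp4"
    bits=4,
    group_size=32,
    scheme="sym",
    quantile=1.0,
    return_int=True,  # True also returns the integer weight plus scale/zp
    full_range=False,
)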
151 changes: 63 additions & 88 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -21,34 +21,33 @@

import torch

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils.utility import set_module
from neural_compressor.torch.utils import logger, set_module

from .utility import quant_tensor, search_clip


@torch.no_grad()
def rtn_quantize(
model,
num_bits=4,
dtype="int",
bits=4,
scheme="sym",
group_size=32,
scheme="asym",
group_dim=1,
quantile=1.0,
weight_config={},
return_int=False,
dtype="int",
enable_full_range=False,
enable_mse_search=False,
group_dim=1,
export_compressed_model=False,
use_full_range=False,
use_mse_search=False,
**kwargs,
):
"""Quant the model with round to nearest method.
"""Quant the model with round to nearest method and inplace is True.

Args:
model: torch module
num_bits: num bits. Defaults to 4.
bits (int, optional): number of quantization bits. Defaults to 4.
group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
scheme (str, optional): sym or asym. Defaults to "asym".
scheme (str, optional): sym or asym. Defaults to "sym".
quantile (float, optional): percentile of clip. Defaults to 1.0.
dtype (str, optional): select from int, nf4, fp4. Defaults to int.
weight_config (dict, optional): specific layer wise configurations. Defaults to {}.
@@ -60,88 +59,98 @@ def rtn_quantize(
'bits': 4,
'group_size': 32,
'scheme': 'sym'
'gptq_perm': [1, 1, ...] # for gptq perm
}
}
return_int (bool, optional): Choose return fp32 or int32 model.
export_compressed_model (bool, optional): Whether to return a compressed int model instead of a fake-quantized fp32 model.
Defaults to False.
enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
use_full_range (bool, optional): Whether the symmetric range uses -2**(bits-1).
Defaults to False.
enable_mse_search (bool, optional): Whether search clip range.
use_mse_search (bool, optional): Whether to search for the optimal clip range.
Defaults to False.
group_dim (int, optional): 0 means splitting output channel,
1 means splitting input channel. Defaults to 1.

Returns:
model: fake quantized torch module
"""
device = "cpu"
assert isinstance(model, torch.nn.Module), "only support torch module"
supported_layers = ["Linear"]
double_quant_dtype = kwargs.get("double_quant_dtype", "fp32")
# initialize global configuration
double_quant_config = {
"double_quant": False if double_quant_dtype == "fp32" else True,
"double_quant_dtype": double_quant_dtype,
"double_quant_num_bits": kwargs.get("double_quant_num_bits", 8),
"double_quant": kwargs.get("use_double_quant", False),
"double_quant_dtype": kwargs.get("double_quant_dtype", "int"),
"double_quant_bits": kwargs.get("double_quant_bits", 8),
"double_quant_scheme": kwargs.get("double_quant_scheme", "sym"),
"double_quant_group_size": kwargs.get("double_quant_group_size", 256),
}
if return_int:
compression_dtype = kwargs.get("compression_dtype", torch.int32)
compression_dim = kwargs.get("compression_dim", 1)
scale_dtype = kwargs.get("scale_dtype", torch.float32)
device = kwargs.get("device", "cpu")
if export_compressed_model:
use_optimum_format = kwargs.get("use_optimum_format", True)
for name, m in model.named_modules():
if m.__class__.__name__ not in supported_layers:
continue
if name in weight_config: # pragma: no cover
# initialize op configuration
dtype = weight_config[name].get("dtype", "int")
num_bits = weight_config[name]["bits"]
bits = weight_config[name].get("bits", 4)
group_size = weight_config[name]["group_size"]
scheme = weight_config[name]["scheme"]
quantile = weight_config[name].get("quantile", 1.0)
group_dim = weight_config[name]["group_dim"]
use_full_range = weight_config[name]["use_full_range"]
use_mse_search = weight_config[name]["use_mse_search"]
use_layer_wise = weight_config[name]["use_layer_wise"]
export_compressed_model = weight_config[name]["export_compressed_model"]
if export_compressed_model:
use_optimum_format = kwargs.get("use_optimum_format", True)
# double quant config
double_quant_config = {
"double_quant": weight_config[name]["use_double_quant"],
"double_quant_dtype": weight_config[name]["double_quant_dtype"],
"double_quant_bits": weight_config[name]["double_quant_bits"],
"double_quant_scheme": weight_config[name]["double_quant_scheme"],
"double_quant_group_size": weight_config[name]["double_quant_group_size"],
}
log_msg = (
f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
+ f"scheme={scheme}, quantile={quantile}"
f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}"
)
if dtype != "int":
log_msg += f", dtype={dtype}"
elif scheme == "sym": # nf4/fp4 is always [-7,7]
log_msg += f", enable_full_range={enable_full_range}"
log_msg += f", use_full_range={use_full_range}"
if dtype == "fp32":
continue
logger.debug(f"RTN quantized module:{name, m}")
logger.debug(log_msg)
weight = m.weight.T if group_dim == 0 else m.weight
if enable_mse_search:
quantile = search_clip(m, num_bits, group_size, scheme, dtype, enable_full_range)
if return_int:
weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
if use_mse_search:
quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
if export_compressed_model:
int_weight, scale, zp = quant_tensor(
weight,
num_bits,
group_size,
scheme,
quantile,
dtype=dtype,
bits=bits,
group_size=group_size,
scheme=scheme,
quantile=quantile,
return_int=True,
full_range=enable_full_range,
full_range=use_full_range,
**double_quant_config,
)
int_weight = int_weight.T if group_dim == 0 else int_weight
scale = scale.T if group_dim == 0 else scale
zp = zp.T if group_dim == 0 and zp is not None else zp
int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
scale = scale.t_().contiguous() if group_dim == 0 else scale
zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
from neural_compressor.torch.quantization.layers import WeightOnlyLinear

new_module = WeightOnlyLinear(
m.in_features,
m.out_features,
num_bits,
group_size,
bits=bits,
group_size=group_size,
dtype=dtype,
zp=zp is not None,
bias=m.bias is not None,
compression_dtype=compression_dtype,
compression_dim=compression_dim,
scale_dtype=scale_dtype,
use_optimum_format=use_optimum_format,
device=device,
)
new_module.pack(int_weight, scale, zp, m.bias)
@@ -150,50 +159,16 @@ def rtn_quantize(
else:
set_module(model, name, new_module)
else:
q_weight = quant_tensor(
weight = quant_tensor(
weight,
num_bits,
group_size,
scheme,
quantile,
dtype=dtype,
full_range=enable_full_range,
bits=bits,
group_size=group_size,
scheme=scheme,
quantile=quantile,
full_range=use_full_range,
**double_quant_config,
)
q_weight = q_weight.T if group_dim == 0 else q_weight
m.weight.data.copy_(q_weight)
weight = weight.t_().contiguous() if group_dim == 0 else weight
m.weight.data.copy_(weight)
return model


from neural_compressor.torch.quantization.config import RTNConfig


def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
# TODO (Yi) remove it
enable_full_range = quant_config.enable_full_range
enable_mse_search = quant_config.enable_mse_search
group_dim = quant_config.group_dim
dtype = quant_config.weight_dtype
num_bits = quant_config.weight_bits
scheme = "sym" if quant_config.weight_sym else "asym"
group_size = quant_config.weight_group_size
return_int = quant_config.return_int
double_quant_dtype = quant_config.double_quant_dtype
double_quant_num_bits = quant_config.double_quant_bits
double_quant_scheme = "sym" if quant_config.double_quant_sym else "asym"
double_quant_group_size = quant_config.double_quant_group_size
return rtn_quantize(
module,
num_bits,
group_size,
scheme,
return_int=return_int,
dtype=dtype,
enable_full_range=enable_full_range,
enable_mse_search=enable_mse_search,
group_dim=group_dim,
double_quant_dtype=double_quant_dtype,
double_quant_scheme=double_quant_scheme,
double_quant_num_bits=double_quant_num_bits,
double_quant_group_size=double_quant_group_size,
)
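Note: with the apply_rtn_on_single_module helper removed above, rtn_quantize is now called with the renamed keywords directly (bits, use_full_range, use_mse_search, export_compressed_model). A minimal usage sketch under the new signature shown in this file's hunks; the toy model is an assumption:

import torch

from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # toy model for illustration
q_model = rtn_quantize(
    model,
    dtype="int",
    bits=4,
    scheme="sym",
    group_size=32,
    use_full_range=False,
    use_mse_search=False,
    export_compressed_model=False,  # True would pack weights into WeightOnlyLinear
)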