@@ -161,8 +161,8 @@ def itrex_bootstrap_stderr(f, xs, iters):
     if args.approach in ["dynamic", "static"]:
         print("device:", next(user_model.parameters()).device)
         from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig
-        from neural_compressor.torch.quantization.fp8 import quantize_dynamic
-        from neural_compressor.torch.quantization import quantize, quantize_dynamic
+        from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
+        from neural_compressor.torch.quantization import quantize
         if args.precision == "fp8_e4m3":
             dtype = torch.float8_e4m3fn
         else:
neural_compressor/torch/__init__.py (18 changes: 0 additions & 18 deletions)
@@ -11,21 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from neural_compressor.torch.utils.utility import register_algo
-from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry
-
-from neural_compressor.torch.quantization import (
-    quantize,
-    RTNConfig,
-    get_default_rtn_config,
-    GPTQConfig,
-    get_default_gptq_config,
-    StaticQuantConfig,
-    get_default_static_config,
-    SmoothQuantConfig,
-    get_default_sq_config,
-)
-
-from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
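With these re-exports deleted, the package root no longer exposes the quantization API; the names presumably remain importable from `neural_compressor.torch.quantization` directly, as the removed block suggests. A minimal sketch under that assumption (the toy model and the two-argument `quantize` call are illustrative, not from this PR):

```python
# A minimal sketch, assuming the symbols removed above stay importable from
# the quantization subpackage. The toy model is a stand-in.
import torch
from neural_compressor.torch.quantization import quantize, get_default_rtn_config

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
rtn_config = get_default_rtn_config()  # default RTN weight-only config
q_model = quantize(model, rtn_config)  # RTN needs no calibration run_fn
```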
neural_compressor/torch/algorithms/__init__.py (4 changes: 0 additions & 4 deletions)
@@ -11,7 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
-from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .quantization_impl import quantize_dynamic, quantize
+from .fp8_quant import quantize_dynamic, quantize, white_list
@@ -198,7 +198,6 @@ def convert(model, qconfig_mapping):
     return model


-@register_algo(name=FP8_QUANT)
 def quantize(model, qconfig_mapping, run_fn=None, run_args=None, inplace=True):
     q_model = model if inplace else copy.deepcopy(model)
     q_model = prepare(q_model, qconfig_mapping)
@@ -209,7 +208,3 @@ def quantize(model, qconfig_mapping, run_fn=None, run_args=None, inplace=True):
         run_fn(q_model)
     q_model = convert(q_model, qconfig_mapping)
     return q_model
-
-
-# def autotune(fp32_model, quant_config, tune_config, eval_func, ...):
-#     pass
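For context, the `quantize` entry point above keeps its prepare → calibrate → convert pipeline while losing the `@register_algo` registration. A minimal driver sketch, assuming a Habana-enabled environment, a no-argument `get_default_fp8_qconfig()`, and a toy model (none of which appear in this PR):

```python
# Sketch of driving the FP8 flow above; the qconfig call, model, and
# calibration data are assumptions for illustration only.
import torch
from neural_compressor.torch.quantization.config import get_default_fp8_qconfig
from neural_compressor.torch.algorithms.habana_fp8 import quantize

model = torch.nn.Linear(32, 8)

def calib_fn(m):
    # feed a few representative batches so observers can collect statistics
    for _ in range(8):
        m(torch.randn(4, 32))

qconfig_mapping = get_default_fp8_qconfig()
q_model = quantize(model, qconfig_mapping, run_fn=calib_fn, inplace=False)
```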
neural_compressor/torch/algorithms/weight_only/gptq.py (46 changes: 10 additions & 36 deletions)
@@ -757,13 +757,14 @@ def find_params(self, x, weight=False):
         if self.wdtype != "int":
             from .utility import quant_tensor

-            tmp = x.clone()  # make sure x is not replaced
+            tmp = x.clone()  # tmp will be replaced after quant_tensor

             _, scale, zero = quant_tensor(
                 tmp,
-                self.wbits,
-                self.group_size,
-                scheme=self.scheme,
+                dtype=self.wdtype,
+                bits=self.wbits,
+                group_size=self.group_size,
+                scheme=self.scheme,
                 quantile=1.0,
                 return_int=True,
                 full_range=False,
@@ -854,10 +855,10 @@ def find_params(self, x, weight=False):
             self.scale = self.scale.reshape(1, -1)
             quant_tensor(
                 self.scale,
-                self.double_quant_bits,
-                self.double_quant_group_size,
-                scheme=self.double_quant_scheme,
+                dtype=self.double_quant_dtype,
+                bits=self.double_quant_bits,
+                group_size=self.double_quant_group_size,
+                scheme=self.double_quant_scheme,
                 quantile=1.0,
                 return_int=False,
                 full_range=False,
@@ -879,8 +880,7 @@ def quantize(self, x, scale, zero, maxq):
         if self.wdtype != "int":
             from .utility import quantize_4bit

-            tmp = x.clone()
-
+            tmp = x.clone()  # tmp will be replaced after quant_tensor
             return quantize_4bit(tmp, dtype=self.wdtype, scale=scale)
         else:
             if maxq < 0:
@@ -950,33 +950,7 @@ def gptq_config_mapping(configs_mapping: Dict[Tuple[str, Callable], GPTQConfig])
     return weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len


-def gptq_quantize(
-    model,
-    weight_config={},
-    dataloader=None,
-    nsamples=128,
-    use_max_length=True,
-    pad_max_length=2048,
-    device=None,
-    layer_wise=False,
-    model_path=None,
-):
-    """Run weight-only quantization with."""
-    # TODO: unify weight_config keys, add docstring, and support default config
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    if layer_wise:
-        assert model_path is not None, "model_path should not be None when use layer_wise mode"
-    from .gptq import GPTQuantizer
-
-    gptq_quantizer = GPTQuantizer(
-        model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise
-    )
-    fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path)
-    logger.info("GPTQ quantizing done.")
-    return fp32_modified_model, gptq_config
-
-
-def apply_gptq_quantize(model, configs_mapping, *args, **kwargs):
+def gptq_quantize(model, configs_mapping, *args, **kwargs):
     """Apply gptq."""
     # TODO: unify weight_config keys, add docstring, and support default config
     weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len = gptq_config_mapping(
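The hunks above also migrate `quant_tensor` from positional to keyword arguments. A standalone sketch of the new call style, assuming the helper is importable from the `weight_only` utility module referenced by the local imports (the example tensor is made up):

```python
# Keyword-style quant_tensor call as used above; quant_tensor replaces its
# input in place, hence the clone.
import torch
from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor

weight = torch.randn(64, 64)
tmp = weight.clone()  # keep the original weight intact
int_weight, scale, zp = quant_tensor(
    tmp,
    dtype="int",
    bits=4,
    group_size=32,
    scheme="sym",
    quantile=1.0,
    return_int=True,
    full_range=False,
)
```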
neural_compressor/torch/algorithms/weight_only/rtn.py (151 changes: 63 additions & 88 deletions)
@@ -21,34 +21,33 @@

 import torch

-from neural_compressor.torch.utils import logger
-from neural_compressor.torch.utils.utility import set_module
+from neural_compressor.torch.utils import logger, set_module

 from .utility import quant_tensor, search_clip


 @torch.no_grad()
 def rtn_quantize(
     model,
-    num_bits=4,
+    dtype="int",
+    bits=4,
+    scheme="sym",
     group_size=32,
-    scheme="asym",
+    group_dim=1,
     quantile=1.0,
     weight_config={},
-    return_int=False,
-    dtype="int",
-    enable_full_range=False,
-    enable_mse_search=False,
-    group_dim=1,
+    export_compressed_model=False,
+    use_full_range=False,
+    use_mse_search=False,
     **kwargs,
 ):
-    """Quant the model with round to nearest method.
+    """Quant the model with round to nearest method and inplace is True.

     Args:
         model: torch module
-        num_bits: num bits. Defaults to 4.
+        bits: num bits. Defaults to 4.
         group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
-        scheme (str, optional): sym or asym. Defaults to "asym".
+        scheme (str, optional): sym or asym. Defaults to "sym".
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         dtype (str, optional): select from int, nf4, fp4. Defaults to int.
         weight_config (dict, optional): specific layer wise configurations. Defaults to {}.
@@ -60,88 +59,98 @@ def rtn_quantize(
                 'bits': 4,
                 'group_size': 32,
                 'scheme': 'sym'
-                'gptq_perm': [1, 1, ...]  # for gptq perm
             }
         }
-        return_int (bool, optional): Choose return fp32 or int32 model.
+        export_compressed_model (bool, optional): Choose return fp32 or int32 model.
             Defaults to False.
-        enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+        use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
             Defaults to False.
-        enable_mse_search (bool, optional): Whether search clip range.
+        use_mse_search (bool, optional): Whether search clip range.
             Defaults to True.
         group_dim (int, optional): 0 means splitting output channel,
             1 means splitting input channel. Defaults to 1.

     Returns:
         model: fake quantized torch module
     """
device = "cpu"
assert isinstance(model, torch.nn.Module), "only support torch module"
supported_layers = ["Linear"]
double_quant_dtype = kwargs.get("double_quant_dtype", "fp32")
# initialize global configuration
double_quant_config = {
"double_quant": False if double_quant_dtype == "fp32" else True,
"double_quant_dtype": double_quant_dtype,
"double_quant_num_bits": kwargs.get("double_quant_num_bits", 8),
"double_quant": kwargs.get("use_double_quant", False),
"double_quant_dtype": kwargs.get("double_quant_dtype", "int"),
"double_quant_bits": kwargs.get("double_quant_bits", 8),
"double_quant_scheme": kwargs.get("double_quant_scheme", "sym"),
"double_quant_group_size": kwargs.get("double_quant_group_size", 256),
}
if return_int:
compression_dtype = kwargs.get("compression_dtype", torch.int32)
compression_dim = kwargs.get("compression_dim", 1)
scale_dtype = kwargs.get("scale_dtype", torch.float32)
device = kwargs.get("device", "cpu")
if export_compressed_model:
use_optimum_format = kwargs.get("use_optimum_format", True)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
         if name in weight_config:  # pragma: no cover
+            # initialize op configuration
             dtype = weight_config[name].get("dtype", "int")
-            num_bits = weight_config[name]["bits"]
+            bits = weight_config[name].get("bits", 4)
             group_size = weight_config[name]["group_size"]
             scheme = weight_config[name]["scheme"]
             quantile = weight_config[name].get("quantile", 1.0)
+            group_dim = weight_config[name]["group_dim"]
+            use_full_range = weight_config[name]["use_full_range"]
+            use_mse_search = weight_config[name]["use_mse_search"]
+            use_layer_wise = weight_config[name]["use_layer_wise"]
+            export_compressed_model = weight_config[name]["export_compressed_model"]
+            if export_compressed_model:
+                use_optimum_format = kwargs.get("use_optimum_format", True)
+            # double quant config
+            double_quant_config = {
+                "double_quant": weight_config[name]["use_double_quant"],
+                "double_quant_dtype": weight_config[name]["double_quant_dtype"],
+                "double_quant_bits": weight_config[name]["double_quant_bits"],
+                "double_quant_scheme": weight_config[name]["double_quant_scheme"],
+                "double_quant_group_size": weight_config[name]["double_quant_group_size"],
+            }
         log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
+            f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}"
         )
         if dtype != "int":
             log_msg += f", dtype={dtype}"
         elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
+            log_msg += f", use_full_range={use_full_range}"
         if dtype == "fp32":
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, dtype, enable_full_range)
-        if return_int:
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
+        if use_mse_search:
+            quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
+        if export_compressed_model:
             int_weight, scale, zp = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
                 return_int=True,
-                full_range=enable_full_range,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             from neural_compressor.torch.quantization.layers import WeightOnlyLinear

             new_module = WeightOnlyLinear(
                 m.in_features,
                 m.out_features,
-                num_bits,
-                group_size,
+                bits=bits,
+                group_size=group_size,
+                dtype=dtype,
                 zp=zp is not None,
                 bias=m.bias is not None,
                 compression_dtype=compression_dtype,
                 compression_dim=compression_dim,
                 scale_dtype=scale_dtype,
                 use_optimum_format=use_optimum_format,
+                device=device,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
@@ -150,50 +159,16 @@ def rtn_quantize(
             else:
                 set_module(model, name, new_module)
         else:
-            q_weight = quant_tensor(
+            weight = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
-                full_range=enable_full_range,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
+            weight = weight.t_().contiguous() if group_dim == 0 else weight
+            m.weight.data.copy_(weight)
     return model
-
-
-from neural_compressor.torch.quantization.config import RTNConfig
-
-
-def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
-    # TODO (Yi) remove it
-    enable_full_range = quant_config.enable_full_range
-    enable_mse_search = quant_config.enable_mse_search
-    group_dim = quant_config.group_dim
-    dtype = quant_config.weight_dtype
-    num_bits = quant_config.weight_bits
-    scheme = "sym" if quant_config.weight_sym else "asym"
-    group_size = quant_config.weight_group_size
-    return_int = quant_config.return_int
-    double_quant_dtype = quant_config.double_quant_dtype
-    double_quant_num_bits = quant_config.double_quant_bits
-    double_quant_scheme = "sym" if quant_config.double_quant_sym else "asym"
-    double_quant_group_size = quant_config.double_quant_group_size
-    return rtn_quantize(
-        module,
-        num_bits,
-        group_size,
-        scheme,
-        return_int=return_int,
-        dtype=dtype,
-        enable_full_range=enable_full_range,
-        enable_mse_search=enable_mse_search,
-        group_dim=group_dim,
-        double_quant_dtype=double_quant_dtype,
-        double_quant_scheme=double_quant_scheme,
-        double_quant_num_bits=double_quant_num_bits,
-        double_quant_group_size=double_quant_group_size,
-    )
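Putting the renamed parameters together, a minimal usage sketch of the refactored `rtn_quantize`, based on the signature and kwargs shown above (the toy model is an assumption):

```python
# Sketch of calling the refactored rtn_quantize; the model is a stand-in and
# the kwargs mirror the double_quant_config keys read above.
import torch
from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize

model = torch.nn.Sequential(torch.nn.Linear(128, 128))
q_model = rtn_quantize(
    model,  # quantized in place, per the docstring
    dtype="int",
    bits=4,
    scheme="sym",
    group_size=32,
    use_mse_search=False,  # set True to search the clip quantile via search_clip
    use_double_quant=True,  # consumed via kwargs into double_quant_config
    double_quant_bits=8,
    double_quant_group_size=256,
)
```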