neural_compressor/common/base_config.py (7 additions & 1 deletion)

@@ -18,6 +18,7 @@

 from __future__ import annotations

+import copy
 import inspect
 import json
 import os

@@ -539,6 +540,7 @@ def expand(self) -> List[BaseConfig]:
                 tuning_param_pair = dict(zip(tuning_param_name_lst, params_values))
                 tmp_params_dict = {**not_tuning_param_pair, **tuning_param_pair}
                 new_config = self.__class__(**tmp_params_dict)
+                new_config.local_config = copy.deepcopy(self.local_config)
                 logger.info(new_config.to_dict())
                 config_list.append(new_config)
         logger.info("Expanded the %s and got %d configs.", self.__class__.__name__, len(config_list))

@@ -629,9 +631,13 @@ def __eq__(self, other: BaseConfig) -> bool:
         """
         if not isinstance(other, type(self)):
             return False
-        return self.params_list == other.params_list and all(
+
+        params_equal = self.params_list == other.params_list and all(
             getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
         )
+        local_config_equal = self.local_config == other.local_config
+        global_config_equal = self.global_config == other.global_config
+        return params_equal and local_config_equal and global_config_equal


 class ComposableConfig(BaseConfig):
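Reviewer note: a minimal sketch of what the equality and expansion changes buy us, using the same INT8StaticQuantConfig/StaticQuantConfig and set_local API the new test exercises (the import path below is an assumption based on that test). Before this patch, two configs that differed only in their per-op local_config compared equal, and expand() produced configs without the local overrides.

# Sketch (not part of the patch): local_config now participates in __eq__ and expand().
from neural_compressor.torch.quantization.config import INT8StaticQuantConfig, StaticQuantConfig

plain = INT8StaticQuantConfig()
with_override = INT8StaticQuantConfig().set_local(
    "fc1", StaticQuantConfig(w_dtype="fp16", act_dtype="fp16")
)

# The two configs differ only in local_config, so they must no longer compare equal.
assert plain != with_override

# expand() deep-copies local_config into every generated config,
# so the per-op override survives tuning-space expansion.
for expanded in with_override.expand():
    assert expanded.local_config == with_override.local_config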
neural_compressor/torch/algorithms/pt2e_quant/core.py (3 additions & 1 deletion)

@@ -18,6 +18,7 @@

 from typing import Any

+import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

@@ -102,4 +103,5 @@ def half_precision_transformation(self, model, config):
         """
         half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
         logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
-        hp_rewriter.transformation(model, half_precision_node_set)
+        hp_rewriter.transformation(model, half_precision_node_set, torch.float16)
+        hp_rewriter.transformation(model, half_precision_node_set, torch.bfloat16)
neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py

@@ -107,6 +107,7 @@ def _register_pattern_pair(dtype: torch.dtype) -> None:


 _register_pattern_pair(torch.float16)
+_register_pattern_pair(torch.bfloat16)


 def get_filter_fn(node_list, fn):

@@ -201,11 +202,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
     op_name_filters = []
     for op_type_name, config in op_type_configs.items():  # pragma: no cover
         op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
             filter = xpq._get_module_type_filter(op_type)
             op_type_filters.append(filter)
     for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
             filter = xpq._get_module_name_filter(op_name)
             op_name_filters.append(filter)
     node_set_from_user_config = set()
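Reviewer note: with the bf16 pattern pair registered and the parser accepting both "fp16" and "bf16", a per-op bf16 override of the shape below (same form as in the new test; "fc1" is a hypothetical Linear name, and the import path is an assumption) should now yield a module-name filter, so the node lands in the half-precision set that core.py rewrites for both torch.float16 and torch.bfloat16.

# Sketch: route one Linear to bf16 while the rest of the model stays INT8 static.
from neural_compressor.torch.quantization.config import INT8StaticQuantConfig, StaticQuantConfig

quant_config = INT8StaticQuantConfig().set_local(
    "fc1", StaticQuantConfig(w_dtype="bf16", act_dtype="bf16")
)
# _parse_node_candidate_set_from_user_config() now builds a filter for this op name
# because its act_dtype is in ["fp16", "bf16"], not only "fp16".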
neural_compressor/torch/algorithms/pt2e_quant/utility.py (2 additions & 1 deletion)

@@ -26,7 +26,7 @@
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

-from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5
+from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, logger


 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:

@@ -79,6 +79,7 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
 def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
     NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
     if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:  # pragma: no cover
+        logger.debug("Got non-quantizable data types, skipping quantization.")
         return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
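Reviewer note: the early return (now with a debug log) is what keeps pure fp16/bf16 op configs out of the X86Inductor INT8 path so the half-precision rewriter can claim them instead; a rough check of that gating, assuming StaticQuantConfig exposes w_dtype/act_dtype as the new test constructs it.

# Sketch: a pure-bf16 op config maps to no INT8 quantization config at all.
from neural_compressor.torch.algorithms.pt2e_quant.utility import _map_inc_config_to_torch_quant_config
from neural_compressor.torch.quantization.config import StaticQuantConfig

bf16_cfg = StaticQuantConfig(w_dtype="bf16", act_dtype="bf16")
assert _map_inc_config_to_torch_quant_config(bf16_cfg) is None  # skipped; handled by the fp16/bf16 rewriter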
neural_compressor/torch/quantization/autotune.py (12 additions & 3 deletions)

@@ -54,6 +54,15 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)


+def _deepcopy_warp(model):
+    additional_attr_lst = ["_exported", "dynamic_shapes"]
+    original_attr = {key: getattr(model, key, None) for key in additional_attr_lst}
+    new_model = deepcopy(model)
+    for key, value in original_attr.items():
+        setattr(new_model, key, value)
+    return new_model
+
+
 @dump_elapsed_time("Pass auto-tune")
 def autotune(
     model: torch.nn.Module,

@@ -81,7 +90,7 @@ def autotune(
     best_quant_model = None
     eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
     config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
-    baseline: float = eval_func_wrapper.evaluate(deepcopy(model))
+    baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model))
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
     for trial_index, quant_config in enumerate(config_loader, 1):

@@ -90,7 +99,7 @@
         logger.info(quant_config.to_dict())
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(
-            deepcopy(model),
+            _deepcopy_warp(model),
             quant_config=quant_config,
             run_fn=run_fn,
             run_args=run_args,

@@ -112,7 +121,7 @@
             best_quant_config: BaseConfig = best_trial_record.quant_config
             # !!! Make sure to use deepcopy only when inplace is set to `True`.
             q_model = quantize(
-                deepcopy(model),
+                _deepcopy_warp(model),
                 quant_config=best_quant_config,
                 run_fn=run_fn,
                 run_args=run_args,
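Reviewer note: a standalone illustration of what _deepcopy_warp guarantees. The helper exists presumably because attributes tacked onto an exported model (such as `_exported` and `dynamic_shapes`) are not guaranteed to survive a plain deepcopy, while autotune needs them on every working copy. The nn.Linear stand-in below is hypothetical, not taken from the patch.

# Sketch: the wrapper snapshots selected attributes and re-attaches them to the copy.
from copy import deepcopy

import torch


def _deepcopy_warp(model):
    additional_attr_lst = ["_exported", "dynamic_shapes"]
    original_attr = {key: getattr(model, key, None) for key in additional_attr_lst}
    new_model = deepcopy(model)
    for key, value in original_attr.items():
        setattr(new_model, key, value)
    return new_model


m = torch.nn.Linear(4, 4)
m._exported = True       # marker set by the export step
m.dynamic_shapes = None  # shape info recorded at export time

copied = _deepcopy_warp(m)
assert copied._exported is True and copied.dynamic_shapes is None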
test/3x/torch/quantization/test_pt2e_quant.py (19 additions & 0 deletions)

@@ -283,3 +283,22 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         opt_model = torch.compile(converted_model)
         out = opt_model(*example_inputs)
         assert out is not None
+
+    @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
+    @pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
+    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, force_not_import_ipex):
+        # config1: int8 for all
+        # config2: half precision for linear
+        from neural_compressor.torch.quantization.config import INT8StaticQuantConfig
+        from neural_compressor.torch.quantization.autotune import autotune, TuningConfig
+        config1 = INT8StaticQuantConfig()
+        config2 = INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype))
+        tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
+        def fake_eval_fn(model):
+            return 1.0
+        def run_fn(model):
+            for i in range(2):
+                model(*example_inputs)
+        model, example_inputs = self.build_model_include_conv_and_linear()
+        model = export(model, example_inputs=example_inputs)
+        qmodel = autotune(model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs)