neural_compressor/torch/algorithms/pt2e_quant/utility.py (12 additions & 3 deletions)
@@ -17,7 +17,12 @@
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

@@ -48,19 +53,23 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False):
"placeholder": PlaceholderObserver,
"minmax": MinMaxObserver,
"kl": HistogramObserver,
"per_channel_minmax": PerChannelMinMaxObserver,
}
# Force to use placeholder observer for dynamic quantization
if is_dynamic:
algo = "placeholder"
# algo
observer_or_fake_quant_ctr = observer_mapping[algo]
if f"{granularity}_{algo}" in observer_mapping:
observer_or_fake_quant_ctr = observer_mapping[f"{granularity}_{algo}"]
else:
observer_or_fake_quant_ctr = observer_mapping[algo]
# qscheme
qscheme = qscheme_mapping[granularity][sym]
quantization_spec = QuantizationSpec(
dtype=select_dtype,
quant_min=min_max_mapping[select_dtype][0],
quant_max=min_max_mapping[select_dtype][1],
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
ch_axis=0,
qscheme=qscheme,
is_dynamic=is_dynamic,
)
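With this change, the observer lookup first tries the combined `{granularity}_{algo}` key and only falls back to the algo-only entry when no combined entry exists: `per_channel` + `minmax` now resolves to `PerChannelMinMaxObserver`, while, e.g., `per_channel` + `kl` still falls back to `HistogramObserver`. Below is a minimal standalone sketch of that lookup and the resulting per-channel weight spec; `pick_observer` and the hard-coded int8 min/max values are illustrative stand-ins for the module's helpers, not part of the patch.

```python
# Standalone sketch of the granularity-aware observer lookup; mapping entries
# mirror the diff, everything else is simplified for illustration.
import torch
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec

observer_mapping = {
    "placeholder": PlaceholderObserver,
    "minmax": MinMaxObserver,
    "kl": HistogramObserver,
    "per_channel_minmax": PerChannelMinMaxObserver,
}

def pick_observer(granularity: str, algo: str):
    # "per_channel" + "minmax" -> "per_channel_minmax" -> PerChannelMinMaxObserver;
    # "per_channel" + "kl" has no combined entry, so it falls back to HistogramObserver.
    return observer_mapping.get(f"{granularity}_{algo}", observer_mapping[algo])

# Per-channel weight spec: ch_axis=0 places one scale/zero-point per output channel.
weight_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    observer_or_fake_quant_ctr=pick_observer("per_channel", "minmax"),
    ch_axis=0,
    qscheme=torch.per_channel_symmetric,
    is_dynamic=False,
)
```

Setting `ch_axis=0` matches the usual convention for `nn.Linear`/`nn.Conv2d` weights, where dimension 0 is the output channel.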
test/3x/torch/quantization/test_pt2e_quant.py (5 additions & 2 deletions)
@@ -98,7 +98,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return exported_model, example_inputs

     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
-    def test_quantize_simple_model(self, force_not_import_ipex):
+    @pytest.mark.parametrize("granularity", ["per_tensor", "per_channel"])
+    def test_quantize_simple_model(self, granularity, force_not_import_ipex):
+        from neural_compressor.torch.quantization import StaticQuantConfig
+
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
         float_model_output = model(*example_inputs)
         quant_config = None
@@ -107,7 +110,7 @@ def calib_fn(model):
             for i in range(4):
                 model(*example_inputs)

-        quant_config = get_default_static_config()
+        quant_config = StaticQuantConfig(w_granularity=granularity)
         q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
         from torch._inductor import config

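For reference, when `granularity == "per_channel"` the parametrized test reduces to the flow below; a condensed sketch assuming `model` and `example_inputs` come from the exported-model helper used in the test.

```python
# Condensed sketch of the per-channel path the parametrized test now covers.
# Assumes `model` and `example_inputs` come from an export step such as the
# test helper build_simple_torch_model_and_example_inputs().
from neural_compressor.torch.quantization import StaticQuantConfig, quantize

def calib_fn(model):
    # Run a few calibration batches so the observers collect ranges.
    for _ in range(4):
        model(*example_inputs)

quant_config = StaticQuantConfig(w_granularity="per_channel")
q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
out = q_model(*example_inputs)  # quantized output, comparable to the float output
```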