Skip to content

Commit

Permalink
[cherry-pick]Add progress bar and speed up Quantization Pass (#43454)
Browse files Browse the repository at this point in the history
* Add progress bar and speed up Quantization Pass

* fix typo
  • Loading branch information
yghstill authored Jun 16, 2022
1 parent 7e940b8 commit abb0b2d
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
import logging
import numpy as np
import shutil
try:
from tqdm import tqdm
except:
from .utils import tqdm
from inspect import isgeneratorfunction
from .... import io
from .... import core
Expand Down Expand Up @@ -357,38 +361,40 @@ def quantize(self):
self._set_activation_persistable()

if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
self._init_sampling_act_histogram()

batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
self._sampling()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish preparation stage, all batch:" + str(batch_id))
self._init_sampling_act_histogram()

_logger.info("Sampling stage ...")
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))

if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
Expand Down Expand Up @@ -823,8 +829,9 @@ def _collect_activation_abs_min_max(self):
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if var_name not in self._sampling_act_abs_min_max:
self._sampling_act_abs_min_max[
var_name] = [min_value, max_value]
self._sampling_act_abs_min_max[var_name] = [
min_value, max_value
]
else:
if min_value < self._sampling_act_abs_min_max[var_name][0]:
self._sampling_act_abs_min_max[var_name][0] = min_value
Expand Down
Loading

0 comments on commit abb0b2d

Please sign in to comment.