Commit 82f2b71

fix coverage

yghstill committed Apr 2, 2022
1 parent da5ddf7 commit 82f2b71
Showing 4 changed files with 376 additions and 19 deletions.
@@ -387,23 +387,27 @@ def quantize(self):
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))

if self._round_type == 'adaround':
self._adaround_apply()

self._reset_activation_persistable()
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
self._quantized_threshold[var_name] = \
np.array(self._quantized_var_avg[var_name]).mean()
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]:
self._update_program()
else:

if self._round_type == 'adaround':
self._adaround_apply()

self._reset_activation_persistable()

if self._algo == 'min_max':
self._save_input_threhold()
else:
self._update_program()

# save out_threshold for quantized ops.
if not self._onnx_format:
self._save_output_threshold()

if any(op_type in self._quantizable_op_type
for op_type in self._dynamic_quantize_op_type):
self._collect_dynamic_quantize_op_threshold(
@@ -428,6 +432,7 @@ def quantize(self):
return self._program

def _adaround_apply(self):
assert self._algo != "min_max", "The algo should not be min_max."
if self._algo in ["KL", "hist"]:
scale_dict = self._quantized_var_threshold
else:
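Note on the hunk above: after this change the per-algorithm thresholds are computed first, AdaRound (when round_type is 'adaround') is applied for every algorithm except min_max (now enforced by the new assert in _adaround_apply), and only the min_max path saves input thresholds instead of rewriting the program. A runnable sketch of that ordering, condensed from the visible lines; the step descriptions are paraphrases, not method names from the source:

# Sketch only: mirrors the post-change ordering in quantize(); algo and
# round_type correspond to PostTrainingQuantization's arguments of the same name.
def finish_quantization_steps(algo, round_type):
    steps = []
    if algo == 'avg':
        steps.append("average the collected per-activation thresholds")
    if algo in ("KL", "hist"):
        steps.append("calculate KL/hist thresholds")
    if round_type == 'adaround':
        # _adaround_apply() now asserts that algo is not 'min_max'
        steps.append("apply AdaRound weight rounding")
    steps.append("reset activation persistable flags")
    if algo == 'min_max':
        steps.append("save input thresholds only")
    else:
        steps.append("rewrite the program with fake quant/dequant ops")
    return steps

print(finish_quantization_steps("KL", "adaround"))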
@@ -173,7 +173,8 @@ def generate_quantized_model(self,
is_use_cache_file=False,
is_optimize_model=False,
batch_size=10,
batch_nums=10):
batch_nums=10,
onnx_format=False):

place = fluid.CPUPlace()
exe = fluid.Executor(place)
@@ -190,14 +191,28 @@ def generate_quantized_model(self,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model_path)

def run_test(self, model_name, model_url, model_md5, data_name, data_url,
data_md5, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, infer_iterations, quant_iterations):
def run_test(self,
model_name,
model_url,
model_md5,
data_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
infer_iterations,
quant_iterations,
onnx_format=False):
fp32_model_path = self.download_model(model_url, model_md5, model_name)
fp32_model_path = os.path.join(fp32_model_path, model_name)

@@ -211,10 +226,10 @@ def run_test(self, model_name, model_url, model_md5, data_name, data_url,

print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations))
self.generate_quantized_model(fp32_model_path, data_path, algo,
round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, quant_iterations)
self.generate_quantized_model(
fp32_model_path, data_path, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
quant_iterations, onnx_format)

print("Start INT8 inference for {0} on {1} samples ...".format(
model_name, infer_iterations))
@@ -278,5 +293,42 @@ def test_post_training_kl(self):
diff_threshold, infer_iterations, quant_iterations)


class TestPostTrainingKLForMnistONNXFormat(TestPostTrainingQuantization):
def test_post_training_kl_onnx_format(self):
model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
model_md5 = "519b8eeac756e7b4b7bcb2868e880452"
data_name = "quant_lstm_input_data"
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "KL"
round_type = "round"
quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0.01
infer_iterations = 100
quant_iterations = 10
onnx_format = True
self.run_test(
model_name,
model_url,
model_md5,
data_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
infer_iterations,
quant_iterations,
onnx_format=onnx_format)


if __name__ == '__main__':
unittest.main()
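The changes in this test file only thread the new onnx_format flag from run_test and generate_quantized_model into PostTrainingQuantization. For reference, a minimal standalone sketch of that call, assuming the constructor arguments the tests use; the import path, executor, model_dir and sample_generator values are not shown in this diff and are assumptions:

import paddle.fluid as fluid
# Assumed import path for the quantization tool exercised by these tests.
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

place = fluid.CPUPlace()
exe = fluid.Executor(place)

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./nlp_lstm_fp32_model",   # assumed local path to the FP32 model
    sample_generator=val_reader,         # hypothetical reader yielding calibration samples
    batch_size=10,
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["mul", "lstm"],
    round_type="round",
    is_full_quantize=False,
    is_use_cache_file=False,
    optimize_model=False,
    onnx_format=True)                    # the flag wired through by this commit
ptq.quantize()
ptq.save_quantized_model("./int8_model")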
@@ -116,7 +116,8 @@ def generate_quantized_model(self,
is_use_cache_file=False,
is_optimize_model=False,
batch_size=10,
batch_nums=10):
batch_nums=10,
onnx_format=False):

place = fluid.CPUPlace()
exe = fluid.Executor(place)
@@ -134,6 +135,7 @@ def generate_quantized_model(self,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model_path)
@@ -151,7 +153,8 @@ def run_test(self,
diff_threshold,
batch_size=10,
infer_iterations=10,
quant_iterations=5):
quant_iterations=5,
onnx_format=False):

origin_model_path = self.download_model(data_url, data_md5, model_name)
origin_model_path = os.path.join(origin_model_path, model_name)
@@ -166,7 +169,7 @@ def run_test(self,
self.generate_quantized_model(origin_model_path, algo, round_type,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
batch_size, quant_iterations)
batch_size, quant_iterations, onnx_format)

print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
@@ -335,5 +338,72 @@ def test_post_training_mse(self):
infer_iterations, quant_iterations)


class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
def test_post_training_mse_onnx_format(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
onnx_format = True
diff_threshold = 0.01
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(
model_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
onnx_format=onnx_format)


class TestPostTrainingmseForMnistONNXFormatFullQuant(
TestPostTrainingQuantization):
def test_post_training_mse_onnx_format_full_quant(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = True
is_use_cache_file = False
is_optimize_model = False
onnx_format = True
diff_threshold = 0.01
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(
model_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
onnx_format=onnx_format)


if __name__ == '__main__':
unittest.main()
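For completeness, a minimal sketch (not part of this commit) of loading a model saved by save_quantized_model back for inference with the same fluid APIs these tests rely on; the model directory and input shape are assumptions based on the MNIST test above:

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# "./int8_model" stands in for whatever path was passed to save_quantized_model().
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname="./int8_model", executor=exe)

# Dummy MNIST-shaped batch; the real tests feed the downloaded dataset.
image = np.random.random((1, 1, 28, 28)).astype("float32")
results = exe.run(infer_program,
                  feed={feed_names[0]: image},
                  fetch_list=fetch_targets)
print(results[0].shape)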