inc_quantize_model.py
"""
Environment Setting
Intel Optimized TensorFlow 2.5.0 and later require to set environment variable TF_ENABLE_MKL_NATIVE_FORMAT=0 before running Intel® Neural Compressor quantize Fp32 model or deploying the quantized model.
"""
import sys

# Prefer Intel Neural Compressor; fall back to its older names (LPOT, then iLiT).
try:
    import neural_compressor as inc
    print("neural_compressor version {}".format(inc.__version__))
except ImportError:
    try:
        import lpot as inc
        print("LPOT version {}".format(inc.__version__))
    except ImportError:
        import ilit as inc
        print("iLiT version {}".format(inc.__version__))

if inc.__version__ == '1.2':
    print("This script doesn't support LPOT 1.2; please install LPOT 1.1, 1.2.1 or newer")
    sys.exit(1)
import math

import alexnet
import mnist_dataset


def save_int8_frozen_pb(q_model, path):
    """Serialize the quantized graph to a frozen .pb file."""
    from tensorflow.python.platform import gfile
    with gfile.GFile(path, 'wb') as f:
        f.write(q_model.as_graph_def().SerializeToString())
    print("Save to {}".format(path))


class Dataloader(object):
    """Batched iterator over the MNIST test set, yielding (images, labels) pairs."""

    def __init__(self, batch_size):
        self.batch_size = batch_size

    def __iter__(self):
        x_train, y_train, label_train, x_test, y_test, label_test = mnist_dataset.read_data()
        batch_nums = math.ceil(len(x_test) / self.batch_size)
        for i in range(batch_nums - 1):
            begin = i * self.batch_size
            end = (i + 1) * self.batch_size
            yield x_test[begin: end], label_test[begin: end]
        # The last batch may be smaller than batch_size.
        begin = (batch_nums - 1) * self.batch_size
        yield x_test[begin:], label_test[begin:]
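
# Illustrative usage (not part of the original script): the dataloader yields
# (images, labels) batches of batch_size items, with a possibly smaller final batch:
#     for images, labels in Dataloader(batch_size=200):
#         ...  # feed each batch to calibration or evaluation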


def auto_tune(input_graph_path, yaml_config, batch_size):
    """Run Neural Compressor auto-tuning and return the quantized INT8 model."""
    fp32_graph = alexnet.load_pb(input_graph_path)
    quan = inc.Quantization(yaml_config)
    dataloader = Dataloader(batch_size)
    # The same dataloader serves both calibration (q_dataloader) and accuracy
    # evaluation (eval_dataloader); eval_func=None means the built-in metric is used.
    q_model = quan(
        fp32_graph,
        q_dataloader=dataloader,
        eval_func=None,
        eval_dataloader=dataloader)
    return q_model
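
# Illustrative sketch of what "alexnet.yaml" might contain (assumption: the exact
# fields depend on the Neural Compressor/LPOT version, and the values below are
# examples rather than the sample's actual config):
#
#   model:
#     name: alexnet
#     framework: tensorflow
#   tuning:
#     accuracy_criterion:
#       relative: 0.01
#     exit_policy:
#       timeout: 0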


yaml_file = "alexnet.yaml"
batch_size = 200
fp32_frozen_pb_file = "fp32_frozen.pb"
int8_pb_file = "alexnet_int8_model.pb"

# Quantize the FP32 frozen graph and save the INT8 result.
q_model = auto_tune(fp32_frozen_pb_file, yaml_file, batch_size)
save_int8_frozen_pb(q_model, int8_pb_file)