forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Quantization] 8bit Quantization and GPU Support
[Quantization] CuDNN 8bit quantized relu v0.1 [Quantization] CuDNN 8bit quantized max_pool v0.1 [Quantization] CuDNN 8bit quantized lrn v0.1 [Quantization] CuDNN 8bit quantized convolution v0.1 [Quantization] CuDNN 8bit quantized fully connected v0.1 [Quantization] Small fix [Quantization] Implement backward method [Quantization] Convolution backward method [Quantization] Add range for matmul and conv [Quantization] New types in ndarray.py [Quantization] 8bit conv works [Quantization] conv support multiple type [Quantization] matmul works now [Quantization] matmul works well [Quantization] efactor quantization operators [Quantization] Op: quantize_down_and_shrink_range [Quantization] Complete quantize_graph_pass [Quantization] Add example [Quantization] Take zero-center quantize, accuracy fixed [Quantization] Multiple layers MLP pass [Quantization] Make quantized_conv same as Convolution [Quantization] quantized_conv works [Quantization] Fix bug [Quantization] lenet works now [Quantization] Add quantized_flatten [Quantization] Quantized max pool works well [Quantization] Make quantized_conv support NHWC [Quantization] add max_pool [Quantization] add ignore_symbols [Quantization] Save change [Quantization] Reorganize tests, 8 layers resnet works on cifar [Quantization] Support for 'NHWC' max pool [Quantization] Support for 'NHWC' quantized max pool [Quantization] Fix speed of quantize_down_and_shrink_range [Quantization] script for resnet on imagenet [Quantization] refactor for quantize offline [Quantization] Fix infershape [Quantization] Update test [Quantization] Update example [Quantization] Fix build error
- Loading branch information
1 parent
e33da52
commit 8cebd6a
Showing
54 changed files
with
4,276 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import mxnet as mx | ||
import numpy as np | ||
import logging | ||
import os | ||
from sklearn.datasets import fetch_mldata | ||
from mxnet.quantization import * | ||
import mxnet.ndarray as nd | ||
|
||
logger = logging.getLogger() | ||
logger.setLevel(logging.DEBUG) | ||
|
||
INFERENCE = False | ||
no_bias = True | ||
batch_size = 32 | ||
name = "conv_mnist" | ||
|
||
data = mx.symbol.Variable('data') | ||
conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), | ||
num_filter=20, no_bias=True, layout='NHWC') | ||
relu1 = mx.symbol.relu(data=conv1) | ||
flatten = mx.symbol.flatten(data=relu1) | ||
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=10) | ||
conv_net = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') | ||
|
||
# prepare data | ||
mnist = fetch_mldata('MNIST original') | ||
np.random.seed(1234) # set seed for deterministic ordering | ||
p = np.random.permutation(mnist.data.shape[0]) | ||
X = mnist.data[p].reshape(70000, 28, 28, 1) | ||
pad = np.zeros(shape=(70000, 28, 28, 3)) | ||
X = np.concatenate([X, pad], axis=3) | ||
Y = mnist.target[p] | ||
|
||
X = X.astype(np.float32)/255 | ||
X_train = X[:60000] | ||
X_test = X[60000:] | ||
Y_train = Y[:60000] | ||
Y_test = Y[60000:] | ||
|
||
train_iter = mx.io.NDArrayIter(X_train, Y_train, batch_size=batch_size) | ||
val_iter = mx.io.NDArrayIter(X_test, Y_test, batch_size=batch_size) | ||
|
||
# create a trainable module on GPU 0 | ||
model = mx.mod.Module(symbol=conv_net, context=mx.gpu(0)) | ||
if not INFERENCE: | ||
# train with the same | ||
model.fit(train_iter, | ||
eval_data=val_iter, | ||
optimizer='sgd', | ||
optimizer_params={'learning_rate':0.1}, | ||
eval_metric='acc', | ||
batch_end_callback = mx.callback.Speedometer(batch_size, 100), | ||
num_epoch=10) | ||
model.save_checkpoint(name, 10) | ||
else: | ||
_, arg_params, aux_params = mx.model.load_checkpoint(name, 10) | ||
model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) | ||
model.set_params(arg_params=arg_params, aux_params=aux_params) | ||
|
||
|
||
test_iter = val_iter | ||
# predict accuracy for conv net | ||
acc = mx.metric.Accuracy() | ||
print('Accuracy: {}%'.format(model.score(test_iter, acc)[0][1]*100)) | ||
|
||
quantized_conv_net = quantize_graph(conv_net) | ||
print(quantized_conv_net.debug_str()) | ||
params = model.get_params()[0] | ||
|
||
def test(symbol): | ||
model = mx.model.FeedForward( | ||
symbol, | ||
ctx=mx.gpu(0), | ||
arg_params=params) | ||
print 'Accuracy:', model.score(test_iter)*100, '%' | ||
|
||
print('origin:') | ||
test(conv_net) | ||
print('after quantization:') | ||
test(quantized_conv_net) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#! /bin/sh | ||
python -i resnet_imagenet.py --model=imagenet1k-resnet-152 --data-val=./data/imagenet/imagenet1k-val.rec --gpus=0 --data-nthreads=60 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import mxnet as mx | ||
import numpy as np | ||
import logging | ||
import os | ||
from sklearn.datasets import fetch_mldata | ||
from mxnet.quantization import * | ||
import mxnet.ndarray as nd | ||
|
||
logger = logging.getLogger() | ||
logger.setLevel(logging.DEBUG) | ||
|
||
no_bias = True | ||
batch_size = 32 | ||
name = 'lenet_mnist' | ||
|
||
data = mx.symbol.Variable('data') | ||
conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=20, no_bias=True) | ||
relu1 = mx.symbol.relu(data=conv1) | ||
pool1 = mx.symbol.max_pool(data=relu1, kernel=(2, 2), stride=(2, 2)) | ||
|
||
conv2 = mx.symbol.Convolution(data=pool1, kernel=(5, 5), num_filter=48, no_bias=True) | ||
relu2 = mx.symbol.relu(data=conv2) | ||
pool2 = mx.symbol.max_pool(data=relu2, kernel=(2, 2), stride=(2, 2)) | ||
|
||
flatten = mx.symbol.flatten(data=pool2) | ||
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500, no_bias=True) | ||
relu3 = mx.symbol.relu(data=fc1) | ||
|
||
fc2 = mx.symbol.FullyConnected(data=relu3, num_hidden=10, no_bias=True) | ||
lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax') | ||
|
||
|
||
# prepare data | ||
mnist = fetch_mldata('MNIST original') | ||
np.random.seed(1234) # set seed for deterministic ordering | ||
p = np.random.permutation(mnist.data.shape[0]) | ||
X = mnist.data[p].reshape(70000, 1, 28, 28) | ||
pad = np.zeros(shape=(70000, 3, 28, 28)) | ||
X = np.concatenate([X, pad], axis=1) | ||
Y = mnist.target[p] | ||
|
||
X = X.astype(np.float32)/255 | ||
X_train = X[:60000] | ||
X_test = X[60000:] | ||
Y_train = Y[:60000] | ||
Y_test = Y[60000:] | ||
|
||
train_iter = mx.io.NDArrayIter(X_train, Y_train, batch_size=batch_size) | ||
val_iter = mx.io.NDArrayIter(X_test, Y_test, batch_size=batch_size) | ||
|
||
# create a trainable module on GPU 0 | ||
lenet_model = mx.mod.Module(symbol=lenet, context=mx.gpu(0)) | ||
# train with the same | ||
# lenet_model.fit(train_iter, | ||
# eval_data=val_iter, | ||
# optimizer='sgd', | ||
# optimizer_params={'learning_rate':0.1}, | ||
# eval_metric='acc', | ||
# batch_end_callback = mx.callback.Speedometer(batch_size, 100), | ||
# num_epoch=10) | ||
# lenet_model.save_checkpoint(name, 10) | ||
sym, arg_params, aux_params = mx.model.load_checkpoint(name, 10) | ||
lenet_model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) | ||
lenet_model.set_params(arg_params=arg_params, aux_params=aux_params) | ||
|
||
|
||
test_iter = val_iter | ||
# predict accuracy for lenet | ||
acc = mx.metric.Accuracy() | ||
print('Accuracy: {}%'.format(lenet_model.score(test_iter, acc)[0][1]*100)) | ||
|
||
quantized_lenet = quantize_graph(lenet) | ||
print(quantized_lenet.debug_str()) | ||
params = lenet_model.get_params()[0] | ||
|
||
def test(symbol): | ||
model = mx.model.FeedForward( | ||
symbol, | ||
ctx=mx.gpu(0), | ||
arg_params=params) | ||
print 'Accuracy:', model.score(test_iter)*100, '%' | ||
|
||
print('origin:') | ||
test(lenet) | ||
print('after quantization:') | ||
test(quantized_lenet) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import mxnet as mx | ||
import numpy as np | ||
import logging | ||
from sklearn.datasets import fetch_mldata | ||
from mxnet.quantization import * | ||
|
||
logger = logging.getLogger() | ||
logger.setLevel(logging.DEBUG) | ||
|
||
name = 'mlp_mnist' | ||
no_bias = True | ||
batch_size = 32 | ||
INFERENCE = True | ||
|
||
data = mx.symbol.Variable('data') | ||
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=32, no_bias=no_bias) | ||
act1 = mx.symbol.relu(data = fc1, name='act1') | ||
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, no_bias=no_bias) | ||
act2 = mx.symbol.relu(data = fc2) | ||
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10, no_bias=no_bias) | ||
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax') | ||
|
||
print(mlp.list_arguments()) | ||
|
||
|
||
# prepare data | ||
mnist = fetch_mldata('MNIST original') | ||
np.random.seed(1234) # set seed for deterministic ordering | ||
p = np.random.permutation(mnist.data.shape[0]) | ||
X = mnist.data[p] | ||
Y = mnist.target[p] | ||
|
||
X = X.astype(np.float32)/255 | ||
X_train = X[:60000] | ||
X_test = X[60000:] | ||
Y_train = Y[:60000] | ||
Y_test = Y[60000:] | ||
|
||
train_iter = mx.io.NDArrayIter(X_train, Y_train, batch_size=batch_size) | ||
test_iter = mx.io.NDArrayIter(X_test, Y_test, batch_size=batch_size) | ||
val_iter = test_iter | ||
|
||
model = mx.mod.Module(symbol=mlp, context=mx.gpu(0)) | ||
if not INFERENCE: | ||
model.fit(train_iter, | ||
eval_data=val_iter, | ||
optimizer='sgd', | ||
optimizer_params={'learning_rate':0.1}, | ||
eval_metric='acc', | ||
batch_end_callback = mx.callback.Speedometer(batch_size, 200), | ||
num_epoch=10) | ||
model.save_checkpoint(name, 10) | ||
else: | ||
_, arg_params, aux_params = mx.model.load_checkpoint(name, 10) | ||
model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) | ||
model.set_params(arg_params=arg_params, aux_params=aux_params) | ||
|
||
acc = mx.metric.Accuracy() | ||
print('Accuracy: {}%'.format(model.score(test_iter, acc)[0][1]*100)) | ||
|
||
|
||
quantized_mlp = quantize_graph(mlp) | ||
print(quantized_mlp.debug_str()) | ||
params = model.get_params()[0] | ||
|
||
def test(symbol): | ||
model = mx.model.FeedForward( | ||
symbol, | ||
ctx=mx.gpu(0), | ||
arg_params=params) | ||
print 'Accuracy:', model.score(test_iter)*100, '%' | ||
|
||
print('origin:') | ||
test(mlp) | ||
print('after quantization:') | ||
test(quantized_mlp) |
Oops, something went wrong.