Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix unit tests for python3
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jan 16, 2018
1 parent 5c81473 commit 2862776
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
20 changes: 13 additions & 7 deletions benchmark/python/quantization/benchmark_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
from mxnet.test_utils import check_speed


def quantize_int8_helper(data):
min_data = mx.nd.min(data)
max_data = mx.nd.max(data)
return mx.nd.contrib.quantize(data, min_data, max_data, out_type='int8')


def benchmark_convolution(data_shape, kernel, num_filter, pad, stride, no_bias=True, layout='NCHW', repeats=20):
ctx_gpu = mx.gpu(0)
data = mx.sym.Variable(name="data", shape=data_shape)
data = mx.sym.Variable(name="data", shape=data_shape, dtype='float32')
# conv cudnn
conv_cudnn = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
no_bias=no_bias, layout=layout, cudnn_off=False, name="conv_cudnn")
Expand All @@ -28,12 +34,12 @@ def benchmark_convolution(data_shape, kernel, num_filter, pad, stride, no_bias=T
kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
no_bias=no_bias, layout=layout, cudnn_off=False,
name='quantized_conv2d')
qargs = {qdata.name: mx.quantization.quantize(input_data)[0],
min_data.name: mx.quantization.quantize(input_data)[1],
max_data.name: mx.quantization.quantize(input_data)[2],
weight.name: mx.quantization.quantize(args[conv_weight_name])[0],
min_weight.name: mx.quantization.quantize(args[conv_weight_name])[1],
max_weight.name: mx.quantization.quantize(args[conv_weight_name])[2]}
qargs = {qdata.name: quantize_int8_helper(input_data)[0],
min_data.name: quantize_int8_helper(input_data)[1],
max_data.name: quantize_int8_helper(input_data)[2],
weight.name: quantize_int8_helper(args[conv_weight_name])[0],
min_weight.name: quantize_int8_helper(args[conv_weight_name])[1],
max_weight.name: quantize_int8_helper(args[conv_weight_name])[2]}
qconv_time = check_speed(sym=quantized_conv2d, location=qargs, ctx=ctx_gpu, N=repeats,
grad_req='null', typ='forward') * 1000

Expand Down
14 changes: 7 additions & 7 deletions python/mxnet/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
th = max(abs(min_val), abs(max_val))

hist, hist_edeges = np.histogram(arr, bins=num_bins, range=(-th, th))
zero_bin_idx = num_bins / 2
num_half_quantized_bins = num_quantized_bins / 2
assert np.allclose(hist_edeges[int(zero_bin_idx)] + hist_edeges[int(zero_bin_idx + 1)], 0, rtol=1e-5, atol=1e-7)
zero_bin_idx = num_bins // 2
num_half_quantized_bins = num_quantized_bins // 2
assert np.allclose(hist_edeges[zero_bin_idx] + hist_edeges[zero_bin_idx + 1], 0, rtol=1e-5, atol=1e-7)

thresholds = np.zeros(num_bins / 2 + 1 - num_quantized_bins / 2)
thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
divergence = np.zeros_like(thresholds)
quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
for i in range(num_quantized_bins / 2,
num_bins / 2 + 1): # i means the number of bins on half axis excluding the zero bin
for i in range(num_quantized_bins // 2,
num_bins // 2 + 1): # i means the number of bins on half axis excluding the zero bin
p_bin_idx_start = zero_bin_idx - i
p_bin_idx_stop = zero_bin_idx + i + 1
thresholds[i - num_half_quantized_bins] = hist_edeges[p_bin_idx_stop]
Expand All @@ -232,7 +232,7 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
is_nonzeros = (sliced_nd_hist != 0).astype(np.int32)

# calculate how many bins should be merged to generate quantized distribution q
num_merged_bins = p.size / num_quantized_bins
num_merged_bins = p.size // num_quantized_bins
# merge hist into num_quantized_bins bins
for j in range(num_quantized_bins):
start = j * num_merged_bins
Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import mxnet as mx
import numpy as np
from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same
from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same, set_default_context


def test_quantize_float32_to_int8():
Expand Down Expand Up @@ -348,5 +348,6 @@ def get_threshold(nd):


if __name__ == "__main__":
set_default_context(mx.gpu(0))
import nose
nose.runmodule()

0 comments on commit 2862776

Please sign in to comment.