-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[Mxnet-1397] Support symbolic api for requantize and dequantize #14749
Changes from 3 commits
78c4046
21a43e2
3bb9f4c
25ef1a1
542256d
64426de
ea4927a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,10 @@ inference accuracy. | |
.set_attr_parser(ParamParser<RequantizeParam>) | ||
.set_num_inputs(3) | ||
.set_num_outputs(3) | ||
.set_attr<nnvm::FListInputNames>("FListInputNames", | ||
[](const NodeAttrs& attrs) { | ||
return std::vector<std::string>{"data", "min_range", "max_range"}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as above. |
||
}) | ||
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape) | ||
.set_attr<nnvm::FInferType>("FInferType", RequantizeType) | ||
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,19 +63,45 @@ def test_quantize_float32_to_int8(): | |
|
||
@with_seed() | ||
def test_dequantize_int8_to_float32(): | ||
|
||
def get_test_data(real_range, qdata_np): | ||
qdata = mx.nd.array(qdata_np, dtype=np.int8) | ||
min_range = mx.nd.array([-real_range], dtype=np.float32) | ||
max_range = mx.nd.array([real_range], dtype=np.float32) | ||
return qdata, min_range, max_range | ||
|
||
def baseline_dequantization(qdata, real_range, qdata_np): | ||
quantized_range = 127.0 | ||
scale = real_range / quantized_range | ||
data_np = qdata_np * scale | ||
return data_np | ||
|
||
def test_nd_array_dequantization(qdata, min_range, max_range, expected_result): | ||
data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') | ||
assert data.dtype == np.float32 | ||
assert_almost_equal(data.asnumpy(), expected_result) | ||
|
||
def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result): | ||
sym_data = mx.sym.Variable('data') | ||
sym_min_range = mx.sym.Variable('min_range') | ||
sym_max_range = mx.sym.Variable('max_range') | ||
dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range, | ||
sym_max_range, out_type='float32') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: indent? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
out = dequant.bind(ctx=mx.cpu(), args={'data':qdata, 'min_range':min_range, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use ctx=mx.current_context() so this test can cover both CPU and GPU computation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
'max_range':max_range}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indent? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
data = out.forward()[0] | ||
assert data.dtype == np.float32 | ||
assert_almost_equal(data.asnumpy(), expected_result) | ||
|
||
real_range = 402.3347 | ||
shape = rand_shape_nd(4) | ||
qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8) | ||
qdata = mx.nd.array(qdata_np, dtype=np.int8) | ||
real_range = 402.3347 | ||
min_range = mx.nd.array([-real_range], dtype=np.float32) | ||
max_range = mx.nd.array([real_range], dtype=np.float32) | ||
data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') | ||
quantized_range = 127.0 | ||
scale = real_range / quantized_range | ||
assert data.dtype == np.float32 | ||
data_np = qdata_np * scale | ||
assert_almost_equal(data.asnumpy(), data_np) | ||
|
||
qdata, min_range, max_range = get_test_data(real_range, qdata_np) | ||
expected_result = baseline_dequantization(qdata, real_range, qdata_np) | ||
# test nd array implementation. | ||
test_nd_array_dequantization(qdata, min_range, max_range, expected_result) | ||
# test symbolic api implementation. | ||
test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result) | ||
|
||
@with_seed() | ||
def test_requantize_int32_to_int8(): | ||
|
@@ -124,7 +150,40 @@ def check_requantize(shape, min_calib_range=None, max_calib_range=None): | |
assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1) | ||
assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) | ||
assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) | ||
|
||
def check_requantize_with_symbol(shape, min_calib_range=None, max_calib_range=None): | ||
qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32') | ||
min_range = mx.nd.array([-1010.0]) | ||
max_range = mx.nd.array([1020.0]) | ||
sym_data = mx.sym.Variable('data') | ||
sym_min_range = mx.sym.Variable('min_range') | ||
sym_max_range = mx.sym.Variable('max_range') | ||
if min_calib_range is None or max_calib_range is None: | ||
requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range) | ||
out = requant.bind(ctx=mx.cpu(), args={'data':qdata, 'min_range':min_range, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
'max_range':max_range}) | ||
qdata_int8, min_output, max_output = out.forward() | ||
else: | ||
requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range, | ||
min_calib_range, max_calib_range) | ||
out = requant.bind(ctx=mx.cpu(), args={'data':qdata, 'min_range':min_range, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use ctx=mx.current_context() so this test can cover both CPU and GPU computation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
'max_range':max_range}) | ||
qdata_int8, min_output, max_output = out.forward() | ||
|
||
qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(), | ||
max_range.asscalar(), | ||
min_calib_range=min_calib_range, | ||
max_calib_range=max_calib_range) | ||
assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np) | ||
assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) | ||
assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) | ||
|
||
# test with symbol API. | ||
check_requantize_with_symbol((3, 4, 10, 10)) | ||
check_requantize_with_symbol((32, 3, 23, 23)) | ||
check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0) | ||
check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43) | ||
# Test with nd array API | ||
check_requantize((3, 4, 10, 10)) | ||
check_requantize((32, 3, 23, 23)) | ||
check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If these names will be exposed to front-end users, I hope they can align with other quantization operators. In quantized convolution and quantized FC, I see they are `min_data` and `max_data`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the names are documented well in most of the quantized ops, I think it should be OK. Especially in quantized conv and FC, where there are too many quantized parameters, I think it is easier to understand the API with `min_data` and `max_data`.