[MKL-DNN] Enable s8 support for inner product and 3d input with flatten=false (apache#14466)

* support 3D input inner product with flatten=false

* simplify test case and improve error message

* enable s8s8 inner product

* improve code style

* add tests and improve error message
xinyu-intel authored and ZhennanQin committed Apr 3, 2019
1 parent 35a4507 commit 3b5276c
Showing 7 changed files with 44 additions and 48 deletions.
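
In practical terms, the commit lets the MKL-DNN inner product handle a 3-D (or 4-D) input with flatten=False and lets the quantized path accept signed int8 (s8) activations. A minimal sketch of the flatten=False shape behavior (illustrative only, not part of the diff; assumes an MXNet build with MKL-DNN):

import mxnet as mx

# With flatten=False only the last axis is treated as the feature axis:
# data (32, 512, 2) x weight (100, 2)^T -> output (32, 512, 100)
data = mx.nd.random.uniform(shape=(32, 512, 2))
weight = mx.nd.random.uniform(shape=(100, 2))
bias = mx.nd.zeros((100,))
out = mx.nd.FullyConnected(data=data, weight=weight, bias=bias,
                           num_hidden=100, flatten=False)
print(out.shape)  # (32, 512, 100)
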
10 changes: 8 additions & 2 deletions src/operator/nn/fully_connected.cc
@@ -32,6 +32,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNFC(const NDArray& input) {
int ndim = input.shape().ndim();
return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) &&
input.storage_type() == kDefaultStorage;
}

static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
@@ -94,7 +100,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
#if MXNET_USE_MKLDNN == 1
if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
@@ -141,7 +147,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
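
The SupportMKLDNNFC gate added above routes only float32, default-storage inputs with 1 to 4 dimensions to the MKL-DNN kernel; other inputs keep using the stock CPU implementation. A rough illustration of the gate's conditions (hypothetical variable names, assumes an MKL-DNN build):

import mxnet as mx

eligible = mx.nd.ones((32, 512, 2), dtype='float32')   # float32, 3-D, default storage: MKL-DNN FC path
too_many_dims = mx.nd.ones((2, 3, 4, 5, 6))            # 5-D input: handled by the default CPU kernel
not_fp32 = mx.nd.ones((32, 8), dtype='float64')        # float64: also falls back
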
@@ -50,17 +50,6 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,

NDArray data = in_data[fullc::kData];
NDArray weight = in_data[fullc::kWeight];
const TShape &ishape = data.shape();

CHECK(data.dtype() == mshadow::kUint8)
<< "MKLDNNQuantizedFullyConnected Op only supports uint8 for now, but got "
<< mxnet::op::type_string(data.dtype());

if (ishape.ndim() != 2) {
CHECK(param.flatten)
<< "QuantizedFullyConnected Op only supports flatten=true when ishape.ndim()!=2 for now.";
data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
}

const float min_data =
in_data[num_inputs + quantized_fc_enum::kDataMin].data().dptr<float>()[0];
33 changes: 19 additions & 14 deletions src/operator/quantization/quantized_fully_connected.cc
@@ -50,11 +50,14 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
CHECK(!shape_is_none(in_shape->at(0)))
<< "QuantizedFullyConnectedOp input data shape must be given";
const mxnet::TShape& dshape = in_shape->at(0);
mxnet::TShape wshape = Shape2(param.num_hidden, dshape.ProdShape(1, dshape.ndim()));
if (dshape.ndim() != 2) {
CHECK(param.flatten)
<< "QuantizedFullyConnectedOp only supports flatten=true when ishape.ndim()!=2 for now. ";
index_t num_input;
if (!param.flatten) {
num_input = dshape[dshape.ndim() - 1];
} else {
num_input = dshape.ProdShape(1, dshape.ndim());
}

TShape wshape = Shape2(param.num_hidden, num_input);
SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
if (!param.no_bias) {
mxnet::TShape bshape = Shape1(param.num_hidden);
@@ -65,7 +68,13 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape{1});
}

SHAPE_ASSIGN_CHECK(*out_shape, 0, mxnet::TShape({dshape[0], wshape[0]}));
if (!param.flatten) {
TShape result_shape(dshape);
result_shape[dshape.ndim() - 1] = param.num_hidden;
SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
} else {
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
}
SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));
SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));
return true;
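
A plain-Python restatement of the shape rule implemented above, for clarity (a sketch; the function name is illustrative):

from functools import reduce
from operator import mul

def quantized_fc_shapes(dshape, num_hidden, flatten):
    # Mirrors QuantizedFullyConnectedShape: weight is (num_hidden, num_input),
    # and with flatten=False only the last axis of the data is consumed.
    if flatten:
        num_input = reduce(mul, dshape[1:], 1)
        oshape = (dshape[0], num_hidden)
    else:
        num_input = dshape[-1]
        oshape = dshape[:-1] + (num_hidden,)
    return (num_hidden, num_input), oshape

print(quantized_fc_shapes((32, 512, 2), 100, flatten=False))    # ((100, 2), (32, 512, 100))
print(quantized_fc_shapes((32, 512, 2, 2), 100, flatten=True))  # ((100, 2048), (32, 100))
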
@@ -80,9 +89,9 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_type->size(), 3U);

#if MXNET_USE_MKLDNN == 1
// TODO(ciyong): currently, only uint8 fully_connected is upported,
// int8 fully_connected will be supported after mkldnn v0.18
TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kUint8);
CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
<< "QuantizedFullyConnected only supports int8/uint8 input, while "
<< in_type->at(0) << " is given.";
#else
TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
#endif
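
For intuition on the s8 path this check now admits: the quantized kernel consumes int8 (or uint8) activations and int8 weights and accumulates into int32. A small NumPy sketch of that arithmetic (illustrative only; the real MKL-DNN kernel also folds in scales, bias, and requantization):

import numpy as np

data = np.random.randint(-128, 128, size=(4, 16)).astype(np.int8)     # s8 activations
weight = np.random.randint(-128, 128, size=(8, 16)).astype(np.int8)   # s8 weights, one row per hidden unit
# Widen to int32 before the matmul so the accumulation cannot overflow.
acc = data.astype(np.int32) @ weight.T.astype(np.int32)               # (4, 8) int32 outputs
print(acc.dtype, acc.shape)
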
@@ -182,7 +191,8 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,

if (dshape.ndim() != 2)
CHECK(param.flatten)
<< "QuantizedFullyConnectedOp only supports flatten=true when input_shape!=2 for now. ";
<< "QuantizedFullyConnectedForwardCPU only supports flatten=true "
<< "when dshape.ndim() != 2 for now.";

Tensor<cpu, 2, int8_t> weight = in_data[fullc::kWeight].get<cpu, 2, int8_t>(s);
Tensor<cpu, 2, int8_t> data = in_data[fullc::kData].get_with_shape<cpu, 2, int8_t>(
@@ -276,11 +286,6 @@ void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
if (in_data[fullc::kData].dtype() == mshadow::kInt8) {
FallBackCompute(QuantizedFullyConnectedForwardCPU, attrs, ctx, in_data, req, out_data);
return;
}

MKLDNNQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data);
}
#endif
5 changes: 5 additions & 0 deletions src/operator/quantization/quantized_fully_connected.cu
@@ -75,6 +75,11 @@ void QuantizedFullyConnectedForwardGPU(const nnvm::NodeAttrs& attrs,
mxnet::TShape oshape = out.shape_;
// (m, n) * (k, n).T = (m, k)
// A * B.T = C
if (dshape.ndim() != 2) {
CHECK(param.flatten)
<< "Currently, QuantizedFullyConnected Op only supports flatten=true "
<< "when ishape.ndim()!=2 for GPU.";
}

// row_C = col_C(T) = cublas(col_B * col_A(T)) = cublas(row_B(T), row_A)
// row_C = col_C(T) = cublas(col_B(T) * col_A(T)) = cublas(row_B, row_A)
12 changes: 4 additions & 8 deletions src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -116,11 +116,6 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
NDArray data = in_data[fullc::kData];
NDArray weight = in_data[fullc::kWeight];
NDArray output = out_data[fullc::kOut];
const mxnet::TShape &ishape = data.shape();
if (mkldnn_param.quantized && ishape.ndim() != 2) {
CHECK(default_param.flatten)
<< "QuantizedFullyConnected only supports flatten=true when ishape.ndim() != 2 for now.";
}

mkldnn::memory::desc out_md = GetMemDesc(output);
MKLDNNFCFlattenData(default_param, out_data[fullc::kOut], &data, &out_md);
@@ -307,9 +302,10 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
if (full_param.mkldnn_param.quantized) {
size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;

// TODO(ciyong): currently, only uint8 fully_connected is supported,
// int8 fully_connected will be supported after mkldnn v0.18
// TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kUint8);
CHECK(in_types->at(0) == mshadow::kInt8 ||
in_types->at(0) == mshadow::kUint8)
<< "QuantizedFullyConnected only supports int8/uint8 input, while "
<< in_types->at(0) << " is given.";
for (size_t i = 1; i < in_types->size(); ++i) {
if (i < base_num_inputs) {
TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
13 changes: 3 additions & 10 deletions tests/python/mkl/test_subgraph.py
@@ -124,7 +124,6 @@ def check_quantize(sym, data_shape, out_type, name='conv',
mod.bind(for_training=False,
data_shapes=[('data', data_shape)],
label_shapes=[('softmax_label', label_shape)])

mod.init_params(mx.init.Normal(0.5))
arg_params, aux_params = mod.get_params()

@@ -136,12 +135,10 @@ def check_quantize(sym, data_shape, out_type, name='conv',
output.wait_to_read()
ref_out = mod.get_outputs()

# TODO(ciyong), exclude the second fc due to int8 fully_connected is not
# supported before mkldnn 0.18
excluded_sym_names = []
if mx.current_context() == mx.cpu():
if mx.current_context() == mx.cpu() and gluon_forward == True:
excluded_sym_names += ['sg_mkldnn_fully_connected_0']
excluded_sym_names += ['fc_softmax']
excluded_sym_names += ['sg_mkldnn_fully_connected_1']

calib_data = mx.nd.random.uniform(shape=data_shape)
calib_data = NDArrayIter(data=calib_data)
@@ -193,11 +190,7 @@ def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True
assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3)

# fp32 to int8
# TODO(ciyong), int8 fully_connected will be supported after mkldnn 0.18
if name == 'fc':
out_type_list = ['uint8', 'auto']
else:
out_type_list = ['uint8', 'int8', 'auto']
out_type_list = ['uint8', 'int8', 'auto']

if check_quantization:
for out_type in out_type_list:
8 changes: 5 additions & 3 deletions tests/python/quantization/test_quantization.py
@@ -288,9 +288,6 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
if hasMKL == False:
print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
return
elif qdtype == 'int8' and is_test_for_mkldnn():
print('skipped testing test_quantized_fc for mkldnn cpu int8 since it is not supported yet')
return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
return
@@ -377,6 +374,11 @@ def maxabs(a, b):
assert cond == 0

for qdtype in ['int8', 'uint8']:
if is_test_for_mkldnn():
check_quantized_fc((32, 512, 2), 100, True, qdtype, flatten=False)
check_quantized_fc((32, 512, 2), 100, False, qdtype, flatten=False)
check_quantized_fc((32, 512, 2, 2), 100, True, qdtype, flatten=False)
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype, flatten=False)
check_quantized_fc((32, 512, 2, 2), 100, True, qdtype)
check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
