Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support Quantized Fully Connected by INT8 GEMM #12922

Merged
merged 24 commits into from
Dec 15, 2018
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 152 additions & 1 deletion src/operator/quantization/quantized_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@
* \brief
* \author Ziheng Jiang, Jun Wu
*/
#include <vector>
#include "quantization_utils.h"
#include "../nn/fully_connected-inl.h"

namespace mxnet {
namespace op {

namespace quantized_fc {
enum QuantizedfcOpResource {kTempSpace};
}

bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
Expand Down Expand Up @@ -79,6 +85,145 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
return true;
}

bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
*dispatch_mode = DispatchMode::kFCompute;
if (dev_mask == mshadow::cpu::kDevMask) {
*dispatch_mode = DispatchMode::kFComputeEx;
}
for (size_t i = 0; i < out_attrs->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, kDefaultStorage);
if (common::stype_string((*out_attrs)[i]).compare("unknown") == 0) {
return false;
}
}
for (size_t i = 0; i < in_attrs->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, kDefaultStorage);
if (common::stype_string((*in_attrs)[i]).compare("unknown") == 0) {
return false;
}
}
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if in_attrs has unknown storage types? You need to

  1. Check and assign stype to in_attrs as well.
  2. return false if any stype is unknown in in_attrs and out_attrs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please consider using range for loops for readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @larroy meant to use:

for (auto &v : in_attrs) {
  // ...
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}

struct QuantizedSumInitKernelWithBias {
// init sum data with bias for matrix b (n)
MSHADOW_XINLINE static void Map(int i, int32_t *out,
const int8_t *bias, const float *min_out,
const float *max_out, const float *min_bias,
const float *max_bias) {
typedef int32_t T1;
typedef int8_t T2;
using mshadow::red::limits::MinValue;
using mshadow::red::limits::MaxValue;
float float_for_one_out_quant =
MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
float float_for_one_bias_quant =
MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
if (float_for_one_out_quant != 0) {
out[i] = bias[i] * float_for_one_bias_quant /
float_for_one_out_quant;
} else {
LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this info more verbose and add more details?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

out[i] = 0;
}
}
};

template<typename SrcType>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need a blank line before this line

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix indent.

const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
#if MSHADOW_USE_MKL == 1
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
using namespace mshadow;
using namespace mxnet_op;
size_t num_inputs = param.no_bias ? 2 : 3;
CHECK_EQ(in_data.size(), num_inputs * 3);
CHECK_EQ(out_data.size(), 3U);
const NDArray& data = in_data[0];
const NDArray& weight = in_data[1];
const NDArray& out = out_data[0];
TShape dshape = data.shape();
TShape wshape = weight.shape();
TShape oshape = out.shape();
auto output_temp = out.data().dptr<int32_t>();
auto weight_temp = weight.data().dptr<SrcType>();
auto data_temp = data.data().dptr<SrcType>();
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
KellenSunderland marked this conversation as resolved.
Show resolved Hide resolved
const float alpha = 1.0f;
const float beta = 1.0f;
const CBLAS_OFFSET offsetc = CblasFixOffset;
const MKL_INT8 oa = 0;
const MKL_INT8 ob = 0;
MKL_INT32 oc = 0;
const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
Stream<cpu> *s = ctx.get_stream<cpu>();
// cblas_gemm_s8u8s32 required first matrix must be uint8
// shift data from int8(from -128 to 127) to uint8 (from 0 to 255)
int shift = 128;
Tensor<cpu, 1, uint8_t> shiftdata =
ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
Shape1(m * k), s);
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < m * k; ++i) {
shiftdata.dptr_[i] = data_temp[i] + shift;
}

Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
if (!param.no_bias) {
const NDArray& bias = in_data[2];
Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
in_data[8].data().dptr<float>());
} else {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < m * n; ++i) {
output_temp[i] = 0;
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < n; ++i) {
for (int j = 0; j < k; ++j) {
output_temp[i] -= shift * weight_temp[i * k + j];
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = n; i < m * n; ++i) {
output_temp[i] = output_temp[i % n];
}
cblas_gemm_s8u8s32(CblasRowMajor,
CblasNoTrans,
CblasTrans,
offsetc,
m,
n,
k,
alpha,
shiftdata.dptr_,
k,
oa,
weight.data().dptr<SrcType>(),
k,
ob,
beta,
out.data().dptr<int32_t>(),
n,
&oc);
#else
LOG(FATAL) << "s8u8s32 is only supported by MKL BLAS library";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this error message be made little bit more verbose for users? Like mentioning Quantized INT8.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

#endif
}

NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
and accumulates in type int32 for the output. For each argument, two more arguments of type
Expand Down Expand Up @@ -112,7 +257,14 @@ and max thresholds representing the threholds for quantizing the float32 output
})
.set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FComputeEx>("FComputeEx<cpu>",
QuantizedFullyConnectedForward<int8_t>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "weight.")
.add_argument("bias", "NDArray-or-Symbol", "bias.")
Expand All @@ -135,6 +287,5 @@ NNVM_REGISTER_OP(FullyConnected)
}
return node;
});

} // namespace op
} // namespace mxnet
16 changes: 10 additions & 6 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
def test_quantized_fc():
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
if mx.current_context().device_type != 'gpu':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to run this test on CPU in CI. Could we test to see if 'MKL' is in the env var 'BUILD_TAG' and run the test if it is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KellenSunderland good suggestion! Currently, the CI doesn't include Intel MKL library as BLAS library and @azai91 is working on adding it so that we can have a better coverage, such as batch_gemm, quantization FC, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pengzhao-intel Oh sorry, didn't realize that was the case. If the tests won't pass without full mkl installed and it's not there let's add this in a later PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pengzhao-intel do you mean the full MKL? We already use MKLML on CI.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lebeg yes, I mean full MKL. The MKLML doesn't have the INT8 GEMM now :)

print('skipped testing quantized_fc on cpu since it is not supported yet')
print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
Expand All @@ -283,16 +283,16 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
if qdtype == 'uint8':
data_low = 0.0
data_high = 127.0
data_high = 63.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason of changing this?

Copy link
Contributor Author

@lihaofd lihaofd Oct 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change data range from (-127,127) to (-63, 63) to avoid potential overflow when using igemm in some hardware platform

else:
data_low = -127.0
data_high = 127.0
data_low = -63.0
data_high = 63.0
fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=data_shape).astype('int32')
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[1]).astype('int32')
if not no_bias:
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[2]).astype('int32')
output = fc_fp32_exe.forward()[0]

Expand Down Expand Up @@ -335,6 +335,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)

@with_seed()
def test_quantized_flatten():
Expand Down