Support Quantized Fully Connected by INT8 GEMM (apache#12922)
* add quantized fully connected support

* disable qfc cpu case since s8u8s32 is only supported by MKL BLAS library

* retrigger ci testing

* move implementation to cc file and add STORAGE_TYPE_ASSIGN_CHECK

* fix typo bug

* retrigger the ci test

* fix typo bug

* retrigger ci

* retrigger the ci test

* retrigger the ci

* retrigger the ci test

* retrigger ci test

* fix indent issue

* retrigger the ci

* retrigger the ci test

* add verbose message

* update log message

* using range for loop

* using for auto range

* enable MKL BLAS ci test

* fix typo issue

* use TYPE_ASSIGN_CHECK

* retrigger the ci
Hao Li authored and Ubuntu committed Dec 18, 2018
1 parent c41d873 commit 8a00f18
Showing 2 changed files with 177 additions and 8 deletions.
159 changes: 158 additions & 1 deletion src/operator/quantization/quantized_fully_connected.cc
@@ -23,11 +23,17 @@
* \brief
* \author Ziheng Jiang, Jun Wu
*/
#include <vector>
#include "quantization_utils.h"
#include "../nn/fully_connected-inl.h"

namespace mxnet {
namespace op {

namespace quantized_fc {
enum QuantizedfcOpResource {kTempSpace};
}

bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
@@ -79,6 +85,151 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
return true;
}

bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
*dispatch_mode = DispatchMode::kFCompute;
if (dev_mask == mshadow::cpu::kDevMask) {
*dispatch_mode = DispatchMode::kFComputeEx;
}

// all inputs and outputs use dense (default) storage
for (auto &v : *out_attrs) {
v = kDefaultStorage;
}

for (auto &v : *in_attrs) {
v = kDefaultStorage;
}
return true;
}

struct QuantizedSumInitKernelWithBias {
// initialize the n output accumulators with the bias rescaled to the int32 output scale
MSHADOW_XINLINE static void Map(int i, int32_t *out,
const int8_t *bias, const float *min_out,
const float *max_out, const float *min_bias,
const float *max_bias) {
typedef int32_t T1;
typedef int8_t T2;
using mshadow::red::limits::MinValue;
using mshadow::red::limits::MaxValue;
float float_for_one_out_quant =
MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
float float_for_one_bias_quant =
MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
if (float_for_one_out_quant != 0) {
out[i] = bias[i] * float_for_one_bias_quant /
float_for_one_out_quant;
} else {
LOG(INFO) << "float_for_one_out_quant is 0,"
<< " need to check the why MaxAbs(*min_out, *max_out) of out_data is 0!";
out[i] = 0;
}
}
};
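
The kernel above seeds each of the n output accumulators with the bias re-expressed on the int32 output scale: one int8 bias step is worth MaxAbs(min_bias, max_bias) / 127 in float, one int32 output step is worth MaxAbs(min_out, max_out) / (2^31 - 1), so the bias is scaled by the ratio of the two. A minimal numpy sketch of that rescaling (hypothetical helper name, not part of MXNet):

import numpy as np

def rescale_bias_to_out_scale(bias_q, min_bias, max_bias, min_out, max_out):
    # float value of one quantization step for the int8 bias and the int32 output
    bias_step = max(abs(min_bias), abs(max_bias)) / float(2**7 - 1)
    out_step = max(abs(min_out), abs(max_out)) / float(2**31 - 1)
    if out_step == 0:
        return np.zeros_like(bias_q, dtype=np.int32)
    return (bias_q.astype(np.float64) * (bias_step / out_step)).astype(np.int32)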


template<typename SrcType>
void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
#if MSHADOW_USE_MKL == 1
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
using namespace mshadow;
using namespace mxnet_op;
size_t num_inputs = param.no_bias ? 2 : 3;
CHECK_EQ(in_data.size(), num_inputs * 3);
CHECK_EQ(out_data.size(), 3U);
const NDArray& data = in_data[0];
const NDArray& weight = in_data[1];
const NDArray& out = out_data[0];
TShape dshape = data.shape();
TShape wshape = weight.shape();
TShape oshape = out.shape();
auto output_temp = out.data().dptr<int32_t>();
auto weight_temp = weight.data().dptr<SrcType>();
auto data_temp = data.data().dptr<SrcType>();
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const float alpha = 1.0f;
const float beta = 1.0f;
const CBLAS_OFFSET offsetc = CblasFixOffset;
const MKL_INT8 oa = 0;
const MKL_INT8 ob = 0;
MKL_INT32 oc = 0;
const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
Stream<cpu> *s = ctx.get_stream<cpu>();
// cblas_gemm_s8u8s32 requires the first matrix to be uint8,
// so shift data from int8 (-128 to 127) to uint8 (0 to 255)
int shift = 128;
Tensor<cpu, 1, uint8_t> shiftdata =
ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
Shape1(m * k), s);
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < m * k; ++i) {
shiftdata.dptr_[i] = data_temp[i] + shift;
}

// derive the output (min, max) quantization range from the data and weight ranges
Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
if (!param.no_bias) {
const NDArray& bias = in_data[2];
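// in_data[7] and in_data[8] hold min_bias and max_bias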
Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
in_data[8].data().dptr<float>());
} else {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < m * n; ++i) {
output_temp[i] = 0;
}
}
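// undo the +128 data shift: subtract shift * (row sum of weight) from the first output row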
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < n; ++i) {
for (int j = 0; j < k; ++j) {
output_temp[i] -= shift * weight_temp[i * k + j];
}
}
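// broadcast the first output row (shift compensation plus rescaled bias) to the remaining rows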
#pragma omp parallel for num_threads(omp_threads)
for (int i = n; i < m * n; ++i) {
output_temp[i] = output_temp[i % n];
}
cblas_gemm_s8u8s32(CblasRowMajor,
CblasNoTrans,
CblasTrans,
offsetc,
m,
n,
k,
alpha,
shiftdata.dptr_,
k,
oa,
weight.data().dptr<SrcType>(),
k,
ob,
beta,
out.data().dptr<int32_t>(),
n,
&oc);
#else
LOG(FATAL) << "Quantized fully connected operator relies on cblas_gemm_s8u8s32"
<< " which is only supported by MKL BLAS."
<< " Please build MXNet with USE_BLAS=mkl to leverage this operator.";
#endif
}
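
With beta = 1.0, the GEMM accumulates on top of the initialization prepared above. The +128 shift is valid because (A + 128) * W^T differs from A * W^T by exactly 128 times the row sums of W, which the loops above subtract from the first output row and then broadcast to the rest. A standalone numpy check of that identity (illustrative shapes, not MXNet code):

import numpy as np

rng = np.random.default_rng(0)
m, n, k = 4, 3, 5
data = rng.integers(-128, 128, size=(m, k)).astype(np.int32)    # int8-range data
weight = rng.integers(-128, 128, size=(n, k)).astype(np.int32)  # int8-range weights

ref = data @ weight.T                              # the product we actually want
shifted = (data + 128) @ weight.T                  # what the u8*s8 GEMM computes
compensated = shifted - 128 * weight.sum(axis=1)   # subtract shift * row sums of W
assert np.array_equal(ref, compensated)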

NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
and accumulates in type int32 for the output. For each argument, two more arguments of type
Expand Down Expand Up @@ -112,7 +263,14 @@ and max thresholds representing the threholds for quantizing the float32 output
})
.set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FComputeEx>("FComputeEx<cpu>",
QuantizedFullyConnectedForward<int8_t>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "weight.")
.add_argument("bias", "NDArray-or-Symbol", "bias.")
@@ -135,6 +293,5 @@ NNVM_REGISTER_OP(FullyConnected)
}
return node;
});

} // namespace op
} // namespace mxnet
26 changes: 19 additions & 7 deletions tests/python/quantization/test_quantization.py
@@ -26,6 +26,7 @@
from mxnet.module import Module
from mxnet.io import NDArrayIter
import unittest

def is_test_for_gpu():
return mx.current_context().device_type == 'gpu'
@@ -278,8 +279,15 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
def test_quantized_fc():
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
if mx.current_context().device_type != 'gpu':
    has_mkl = 'MKL' in os.environ.get('BUILD_TAG', '')
    if not has_mkl:
        print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
        return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
return
@@ -291,16 +299,16 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
if qdtype == 'uint8':
    data_low = 0.0
    data_high = 63.0
else:
    data_low = -63.0
    data_high = 63.0
fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=data_shape).astype('int32')
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[1]).astype('int32')
if not no_bias:
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[2]).astype('int32')
output = fc_fp32_exe.forward()[0]

Expand Down Expand Up @@ -343,6 +351,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
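
For reference, a hedged sketch of invoking the operator directly: the input order (data, weight, bias, then one (min, max) float pair per tensor) follows the forward implementation above, but the min_*/max_* keyword names are assumed here since the corresponding add_argument lines are collapsed in this view, and the cpu path requires an MKL BLAS build.

import mxnet as mx

# range keyword names below are assumed; shapes and ranges are illustrative
data = mx.nd.random.uniform(-127, 127, shape=(32, 256)).astype('int8')
weight = mx.nd.random.uniform(-127, 127, shape=(100, 256)).astype('int8')
bias = mx.nd.random.uniform(-127, 127, shape=(100,)).astype('int8')
out, min_out, max_out = mx.nd.contrib.quantized_fully_connected(
    data=data, weight=weight, bias=bias,
    min_data=mx.nd.array([-1.0]), max_data=mx.nd.array([1.0]),
    min_weight=mx.nd.array([-1.0]), max_weight=mx.nd.array([1.0]),
    min_bias=mx.nd.array([-1.0]), max_bias=mx.nd.array([1.0]),
    num_hidden=100)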

@with_seed()
def test_quantized_flatten():
