This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add quantized batch_dot #20680

Merged
merged 6 commits into from
Oct 30, 2021
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -362,6 +362,7 @@
'zeros_like',
'_sg_onednn_conv',
'_sg_onednn_fully_connected',
'_sg_onednn_batch_dot',
'broadcast_mul',
'Convolution_v1',
'IdentityAttachKLSparseReg',
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -615,6 +615,7 @@
'_sg_onednn_fully_connected',
'_sg_onednn_selfatt_qk',
'_sg_onednn_selfatt_valatt',
'_sg_onednn_batch_dot'
])

# Functions that have to be cast to FP32 only for
73 changes: 69 additions & 4 deletions src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -38,22 +38,67 @@
namespace mxnet {
namespace op {

enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
enum DotOut { out = 0, out_min, out_max };

struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
bool transpose_a;
bool transpose_b;
bool quantized;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
bool enable_float_output;               // whether to output dequantized float32 instead of int8/uint8
DMLC_DECLARE_PARAMETER(DNNLDotParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(transpose_b)
.describe("If true then transpose the second input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(quantized).set_default(false).describe("Whether to enable quantization.");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
"The minimum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized convolution op to calculate primitive scale");
DMLC_DECLARE_FIELD(max_calib_range)
.set_default(dmlc::optional<float>())
.describe(
"The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized convolution op to calculate primitive scale");
DMLC_DECLARE_FIELD(enable_float_output)
.set_default(false)
.describe("Whether to enable float32 output.");
}

bool operator==(const DNNLDotParam& other) const {
return this->transpose_a == other.transpose_a && this->transpose_b == other.transpose_b &&
this->quantized == other.quantized && this->min_calib_range == other.min_calib_range &&
this->max_calib_range == other.max_calib_range;
}
};

using batch_dot_fwd_t = dnnl::matmul;
using batch_dot_fwd_pd_t = dnnl::matmul::primitive_desc;

typedef ParamOpSign<DotParam> BatchDotSignature;
typedef ParamOpSign<DNNLDotParam> BatchDotSignature;

class DNNLBatchDotFwd {
public:
static DNNLBatchDotFwd& GetCached(const DotParam& param,
static DNNLBatchDotFwd& GetCached(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs);

DNNLBatchDotFwd(const DotParam& param,
DNNLBatchDotFwd(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs);

void Execute(const std::vector<NDArray>& inputs,
void Execute(const OpContext& ctx,
const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

@@ -62,6 +107,26 @@ class DNNLBatchDotFwd {
std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;
};

template <bool subgraph = true>
void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
DNNLDotParam dnnl_param;
if (!subgraph) {
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
dnnl_param.transpose_a = param.transpose_a;
dnnl_param.transpose_b = param.transpose_b;
dnnl_param.quantized = false;
} else {
dnnl_param = nnvm::get<DNNLDotParam>(attrs.parsed);
}

DNNLBatchDotFwd& fwd = DNNLBatchDotFwd::GetCached(dnnl_param, inputs, outputs);
fwd.Execute(ctx, dnnl_param, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
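
The `subgraph` template parameter lets one forward function serve both the stock `batch_dot` operator, whose attributes still parse to `DotParam`, and the fused `_sg_onednn_batch_dot` subgraph operator, which carries the new `DNNLDotParam`. A minimal sketch of how the two instantiations could be attached at registration time; it follows MXNet's usual `NNVM_REGISTER_OP` / `FComputeEx` convention, but the exact registration code is not part of this hunk:

```cpp
// Illustrative sketch only -- not the registration code from this PR.
#if MXNET_USE_ONEDNN == 1
NNVM_REGISTER_OP(batch_dot)
    // Stock operator: attrs.parsed holds a DotParam, so subgraph = false and
    // DNNLBatchDotForward translates it into a DNNLDotParam itself.
    .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLBatchDotForward<false>);

NNVM_REGISTER_OP(_sg_onednn_batch_dot)
    // Fused subgraph operator: attrs.parsed already holds a DNNLDotParam,
    // including the quantization fields, so the subgraph = true path reads it directly.
    .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLBatchDotForward<true>);
#endif  // MXNET_USE_ONEDNN == 1
```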
157 changes: 123 additions & 34 deletions src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -25,27 +25,21 @@
#if MXNET_USE_ONEDNN == 1

#include "./dnnl_batch_dot-inl.h"
#include "../../quantization/quantization_utils.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(DNNLDotParam);

bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output) {
return inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
return inputs[DotIn::lhs].shape().Size() != 0 && inputs[DotIn::rhs].shape().Size() != 0 &&
output.shape().Size() != 0 &&
(inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kBfloat16);
}

void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
DNNLBatchDotFwd& fwd = DNNLBatchDotFwd::GetCached(param, inputs, outputs);
fwd.Execute(inputs, req, outputs);
(inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
}

DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
using batch_dot_fwd_map = std::unordered_map<BatchDotSignature, DNNLBatchDotFwd, OpHash>;
@@ -56,9 +50,9 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
#endif

BatchDotSignature key(param);
key.AddSign(inputs[0]);
key.AddSign(inputs[1]);
key.AddSign(outputs[0]);
key.AddSign(inputs[DotIn::lhs]);
key.AddSign(inputs[DotIn::rhs]);
key.AddSign(outputs[DotOut::out]);

auto it = fwds.find(key);
if (it == fwds.end()) {
@@ -68,14 +62,40 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
return it->second;
}

DNNLBatchDotFwd::DNNLBatchDotFwd(const DotParam& param,
dnnl::primitive_attr GetQuantizationAttributes(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
dnnl::primitive_attr attr;
float out_scale_ = 1.f;
float lhs_scale_ = GetQuantizeScale(inputs[DotIn::lhs].dtype(),
inputs[DotIn::lhs_min].data().dptr<float>()[0],
inputs[DotIn::lhs_max].data().dptr<float>()[0]);
float rhs_scale_ = GetQuantizeScale(inputs[DotIn::rhs].dtype(),
inputs[DotIn::rhs_min].data().dptr<float>()[0],
inputs[DotIn::rhs_max].data().dptr<float>()[0]);
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
// fused requantize => output is int
out_scale_ = GetQuantizeScale(outputs[DotOut::out].dtype(),
param.min_calib_range.value(),
param.max_calib_range.value()) /
lhs_scale_ / rhs_scale_;
attr.set_output_scales(0, {out_scale_});
} else if (param.enable_float_output) {
out_scale_ = 1.0 / lhs_scale_ / rhs_scale_;
attr.set_output_scales(0, {out_scale_});
}

return attr;
}
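
GetQuantizationAttributes collapses the two input ranges and the optional output calibration range into a single oneDNN output scale: with fused requantization the multiplier is out_scale / (lhs_scale * rhs_scale), and with float output it is 1 / (lhs_scale * rhs_scale). A self-contained sketch of that arithmetic; ApproxQuantizeScale below is an assumption of this example, approximating GetQuantizeScale from quantization_utils.h with the usual 127-wide int8 and 255-wide uint8 conventions:

```cpp
// Standalone sketch of the scale arithmetic used above; not code from the PR.
#include <algorithm>
#include <cmath>
#include <cstdio>

static float ApproxQuantizeScale(bool is_int8, float range_min, float range_max) {
  const float real_range      = std::max(std::fabs(range_min), std::fabs(range_max));
  const float quantized_range = is_int8 ? 127.0f : 255.0f;
  return real_range > 0.0f ? quantized_range / real_range : 1.0f;
}

int main() {
  // Calibrated ranges: int8 lhs in [-2, 2], int8 rhs in [-1, 1],
  // requantized int8 output calibrated to [-4, 4].
  const float lhs_scale = ApproxQuantizeScale(true, -2.0f, 2.0f);  // 63.5
  const float rhs_scale = ApproxQuantizeScale(true, -1.0f, 1.0f);  // 127
  const float out_scale = ApproxQuantizeScale(true, -4.0f, 4.0f);  // 31.75

  // Fused requantize: one multiplier takes int32 accumulators to int8 output.
  const float requantize_scale = out_scale / lhs_scale / rhs_scale;
  // Float output: undo both input scales, no output quantization.
  const float dequantize_scale = 1.0f / lhs_scale / rhs_scale;

  std::printf("requantize scale = %f, dequantize scale = %f\n",
              requantize_scale, dequantize_scale);
  return 0;
}
```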

DNNLBatchDotFwd::DNNLBatchDotFwd(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
auto shape = inputs[0].shape();
auto ndim = shape.ndim();
auto bigDim = shape[0];
auto lhs_shape = inputs[DotIn::lhs].shape();
auto ndim = lhs_shape.ndim();
auto bigDim = lhs_shape[0];
for (size_t i = 1; i < ndim - 2; ++i) {
bigDim *= shape[i];
bigDim *= lhs_shape[i];
}

auto GetMemoryDesc = [&ndim, &bigDim](const NDArray& tensor, const bool transpose) {
@@ -91,37 +111,106 @@ DNNLBatchDotFwd::DNNLBatchDotFwd(const DotParam& param,
}
};

dnnl::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
dnnl::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
dnnl::memory::desc data_md = GetMemoryDesc(inputs[DotIn::lhs], param.transpose_a);
dnnl::memory::desc weights_md = GetMemoryDesc(inputs[DotIn::rhs], param.transpose_b);
dnnl::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
get_dnnl_type(outputs[0].dtype()),
get_dnnl_type(outputs[DotOut::out].dtype()),
dnnl::memory::format_tag::any);
dnnl::matmul::desc fwd_desc(data_md, weights_md, out_md);
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
if (param.quantized) {
auto attrs = GetQuantizationAttributes(param, inputs, outputs);
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(
fwd_desc, attrs, mxnet::CpuEngine::Get()->get_engine());

} else {
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
}

fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
}
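
For reference, the same pattern in plain oneDNN terms: build the matmul descriptor from logical memory descriptors, then pass a `primitive_attr` carrying the output scale when the primitive should requantize. The sketch below targets the oneDNN 2.x API that this code uses (`matmul::desc` plus `set_output_scales`); it is a standalone illustration with made-up shapes, not code from the PR:

```cpp
// Standalone oneDNN 2.x sketch: an int8 batched matmul whose primitive_attr
// carries one common output scale, mirroring the quantized branch above.
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dim batch = 2, M = 4, K = 8, N = 3;
  memory::desc src_md({batch, M, K}, memory::data_type::s8, memory::format_tag::abc);
  memory::desc wei_md({batch, K, N}, memory::data_type::s8, memory::format_tag::abc);
  memory::desc dst_md({batch, M, N}, memory::data_type::s8, memory::format_tag::abc);

  // mask = 0 -> a single scale shared by the whole output tensor.
  primitive_attr attr;
  attr.set_output_scales(0, {0.05f});

  matmul::desc desc(src_md, wei_md, dst_md);
  matmul::primitive_desc pd(desc, attr, eng);
  matmul prim(pd);

  std::vector<int8_t> src(batch * M * K, 1), wei(batch * K * N, 1), dst(batch * M * N, 0);
  memory src_mem(src_md, eng, src.data());
  memory wei_mem(wei_md, eng, wei.data());
  memory dst_mem(dst_md, eng, dst.data());

  prim.execute(strm, {{DNNL_ARG_SRC, src_mem},
                      {DNNL_ARG_WEIGHTS, wei_mem},
                      {DNNL_ARG_DST, dst_mem}});
  strm.wait();
  return 0;
}
```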

void DNNLBatchDotFwd::Execute(const std::vector<NDArray>& inputs,
void DNNLBatchDotFwd::Execute(const OpContext& ctx,
const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
auto engine = mxnet::CpuEngine::Get()->get_engine();
auto data =
dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[0].data().dptr_));
auto weights =
dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(inputs[1].data().dptr_));
dnnl_output_t out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
auto lhs = inputs[DotIn::lhs];
auto rhs = inputs[DotIn::rhs];
// The created primitive descriptor assumes that both inputs are in the default format
if (lhs.IsDNNLData())
lhs = lhs.Reorder2Default();
if (rhs.IsDNNLData())
rhs = rhs.Reorder2Default();

auto lhs_mem =
dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(lhs.data().dptr_));
auto rhs_mem =
dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(rhs.data().dptr_));
dnnl_output_t out_mem = CreateDNNLMem(
outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);

dnnl_args_map_t args = {
{DNNL_ARG_SRC, data},
{DNNL_ARG_WEIGHTS, weights},
{DNNL_ARG_SRC, lhs_mem},
{DNNL_ARG_WEIGHTS, rhs_mem},
{DNNL_ARG_DST, *out_mem.second},
};

DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
CommitOutput(outputs[0], out_mem);
DNNLStream::Get()->Submit();

if (param.quantized && !param.enable_float_output) {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
float min_output;
float max_output;
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
min_output = param.min_calib_range.value();
max_output = param.max_calib_range.value();
} else {
if (inputs[DotIn::lhs].dtype() == mshadow::kInt8) {
mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
s,
1,
&min_output,
&max_output,
inputs[DotIn::rhs_min].data().dptr<float>(),
inputs[DotIn::rhs_max].data().dptr<float>(),
inputs[DotIn::lhs_min].data().dptr<float>(),
inputs[DotIn::lhs_max].data().dptr<float>());
} else {
mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
s,
1,
&min_output,
&max_output,
inputs[DotIn::rhs_min].data().dptr<float>(),
inputs[DotIn::rhs_max].data().dptr<float>(),
inputs[DotIn::lhs_min].data().dptr<float>(),
inputs[DotIn::lhs_max].data().dptr<float>());
}
}

float* min_output_ptr = outputs[DotOut::out_min].data().dptr<float>();
float* max_output_ptr = outputs[DotOut::out_max].data().dptr<float>();
*min_output_ptr = min_output;
*max_output_ptr = max_output;
}
}
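
When no calibrated output range is given, the kernel launches above recover the float range that corresponds to the int32 accumulator: each quantized level of lhs and rhs is worth MaxAbs(range)/127 (or /255 for a uint8 input), and the product of those per-level values scaled by the int32 limit gives the representable output range. A rough, self-contained restatement of that computation; the helper below is an approximation of MXNet's QuantizationRangeForMultiplication, not the exact kernel code:

```cpp
// Approximate, standalone version of the min/max propagation used when the
// quantized batch_dot has no calibrated output range. Assumes the symmetric
// 127/255-wide conventions of quantization_utils.h.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

static float FloatPerLevel(bool is_int8, float range_min, float range_max) {
  const float real_range      = std::max(std::fabs(range_min), std::fabs(range_max));
  const float quantized_range = is_int8 ? 127.0f : 255.0f;
  return real_range / quantized_range;  // float value of one quantized step
}

static void ApproxRangeForMultiplication(bool lhs_is_int8,
                                         float lhs_min, float lhs_max,
                                         float rhs_min, float rhs_max,
                                         float* out_min, float* out_max) {
  // The int32 accumulator holds products of quantized values, so its limits
  // translate back to float through the product of the two per-level sizes.
  const float per_level = FloatPerLevel(lhs_is_int8, lhs_min, lhs_max) *
                          FloatPerLevel(true, rhs_min, rhs_max);
  const float int32_range = static_cast<float>(std::numeric_limits<int32_t>::max());
  *out_min = -per_level * int32_range;
  *out_max = per_level * int32_range;
}

int main() {
  float out_min = 0.0f, out_max = 0.0f;
  ApproxRangeForMultiplication(true, -2.0f, 2.0f, -1.0f, 1.0f, &out_min, &out_max);
  std::printf("output range ~ [%g, %g]\n", out_min, out_max);
  return 0;
}
```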

} // namespace op
} // namespace mxnet

namespace std {
template <>
struct hash<mxnet::op::DNNLDotParam> {
size_t operator()(const mxnet::op::DNNLDotParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.transpose_a);
ret = dmlc::HashCombine(ret, val.transpose_b);
ret = dmlc::HashCombine(ret, val.quantized);
return ret;
}
};
} // namespace std
#endif // MXNET_USE_ONEDNN == 1
1 change: 1 addition & 0 deletions src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -159,6 +159,7 @@ void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs);

/* For batch dot */
template <bool subgraph>
void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,