This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add quantized batch_dot (#20680)
* Add quantized batch_dot

* Fix sanity

* Fix names for post quantize fuse

* Fixes

* update amp list

* fix sanity
bgawrych committed Oct 30, 2021
1 parent 79e1753 commit fb1d395
Showing 13 changed files with 551 additions and 78 deletions.
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -362,6 +362,7 @@
'zeros_like',
'_sg_onednn_conv',
'_sg_onednn_fully_connected',
'_sg_onednn_batch_dot',
'broadcast_mul',
'Convolution_v1',
'IdentityAttachKLSparseReg',
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -620,6 +620,7 @@
'_sg_onednn_fully_connected',
'_sg_onednn_selfatt_qk',
'_sg_onednn_selfatt_valatt',
'_sg_onednn_batch_dot'
])

# Functions that have to be cast to FP32 only for
73 changes: 69 additions & 4 deletions src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -38,22 +38,67 @@
namespace mxnet {
namespace op {

enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
enum DotOut { out = 0, out_min, out_max };

struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
bool transpose_a;
bool transpose_b;
bool quantized;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
bool enable_float_output; // whether to dequantize and output float32 instead of int8
DMLC_DECLARE_PARAMETER(DNNLDotParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(transpose_b)
.describe("If true then transpose the second input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
"The minimum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized convolution op to calculate primitive scale");
DMLC_DECLARE_FIELD(max_calib_range)
.set_default(dmlc::optional<float>())
.describe(
"The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized convolution op to calculate primitive scale");
DMLC_DECLARE_FIELD(enable_float_output)
.set_default(false)
.describe("Whether to enable float32 output.");
}

bool operator==(const DNNLDotParam& other) const {
return this->transpose_a == other.transpose_a && this->transpose_b == other.transpose_b &&
this->quantized == other.quantized && this->min_calib_range == other.min_calib_range &&
this->max_calib_range == other.max_calib_range;
}
};

using batch_dot_fwd_t = dnnl::matmul;
using batch_dot_fwd_pd_t = dnnl::matmul::primitive_desc;

typedef ParamOpSign<DotParam> BatchDotSignature;
typedef ParamOpSign<DNNLDotParam> BatchDotSignature;

class DNNLBatchDotFwd {
public:
static DNNLBatchDotFwd& GetCached(const DotParam& param,
static DNNLBatchDotFwd& GetCached(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs);

DNNLBatchDotFwd(const DotParam& param,
DNNLBatchDotFwd(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs);

void Execute(const std::vector<NDArray>& inputs,
void Execute(const OpContext& ctx,
const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

@@ -62,6 +107,26 @@ class DNNLBatchDotFwd {
std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;
};

template <bool subgraph = true>
void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
DNNLDotParam dnnl_param;
if (!subgraph) {
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
dnnl_param.transpose_a = param.transpose_a;
dnnl_param.transpose_b = param.transpose_b;
dnnl_param.quantized = false;
} else {
dnnl_param = nnvm::get<DNNLDotParam>(attrs.parsed);
}

DNNLBatchDotFwd& fwd = DNNLBatchDotFwd::GetCached(dnnl_param, inputs, outputs);
fwd.Execute(ctx, dnnl_param, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
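For context on the new calibration fields: min_calib_range/max_calib_range (and the per-input min/max tensors) are only consumed to build the oneDNN output scales in GetQuantizationAttributes further down in dnnl_batch_dot.cc. Below is a minimal standalone sketch of that arithmetic, assuming the usual MXNet convention of 127 quantization levels for signed int8 and 255 for uint8; the helper name QuantizeScale is illustrative and not the actual GetQuantizeScale API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Illustrative stand-in for GetQuantizeScale: the scale maps real values onto
// the quantized range (assumed 127 levels for int8, 255 for uint8).
float QuantizeScale(float min_val, float max_val, bool is_signed) {
  const float quantized_range = is_signed ? 127.0f : 255.0f;
  return quantized_range / std::max(std::fabs(min_val), std::fabs(max_val));
}

int main() {
  const float lhs_scale = QuantizeScale(-2.0f, 2.0f, true);   // int8 lhs in [-2, 2]
  const float rhs_scale = QuantizeScale(-0.5f, 0.5f, true);   // int8 rhs in [-0.5, 0.5]

  // Calibrated int8 output: requantize the s32 accumulator straight to int8.
  const float requant_scale = QuantizeScale(-1.0f, 1.0f, true) / lhs_scale / rhs_scale;
  // enable_float_output: only undo the input scales, yielding float32.
  const float dequant_scale = 1.0f / lhs_scale / rhs_scale;

  std::printf("requantize scale = %g, dequantize scale = %g\n", requant_scale, dequant_scale);
  return 0;
}
```

This mirrors the two branches of GetQuantizationAttributes: with calibration data the combined scale is handed to attr.set_output_scales for an int8 result, while the enable_float_output branch uses only the reciprocal of the input scales.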
157 changes: 123 additions & 34 deletions src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -25,27 +25,21 @@
#if MXNET_USE_ONEDNN == 1

#include "./dnnl_batch_dot-inl.h"
#include "../../quantization/quantization_utils.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(DNNLDotParam);

bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output) {
return inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
return inputs[DotIn::lhs].shape().Size() != 0 && inputs[DotIn::rhs].shape().Size() != 0 &&
output.shape().Size() != 0 &&
(inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kBfloat16);
}

void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
DNNLBatchDotFwd& fwd = DNNLBatchDotFwd::GetCached(param, inputs, outputs);
fwd.Execute(inputs, req, outputs);
(inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
}

DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
using batch_dot_fwd_map = std::unordered_map<BatchDotSignature, DNNLBatchDotFwd, OpHash>;
@@ -56,9 +50,9 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
#endif

BatchDotSignature key(param);
key.AddSign(inputs[0]);
key.AddSign(inputs[1]);
key.AddSign(outputs[0]);
key.AddSign(inputs[DotIn::lhs]);
key.AddSign(inputs[DotIn::rhs]);
key.AddSign(outputs[DotOut::out]);

auto it = fwds.find(key);
if (it == fwds.end()) {
@@ -68,14 +62,40 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
return it->second;
}

DNNLBatchDotFwd::DNNLBatchDotFwd(const DotParam& param,
dnnl::primitive_attr GetQuantizationAttributes(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
dnnl::primitive_attr attr;
float out_scale_ = 1.f;
float lhs_scale_ = GetQuantizeScale(inputs[DotIn::lhs].dtype(),
inputs[DotIn::lhs_min].data().dptr<float>()[0],
inputs[DotIn::lhs_max].data().dptr<float>()[0]);
float rhs_scale_ = GetQuantizeScale(inputs[DotIn::rhs].dtype(),
inputs[DotIn::rhs_min].data().dptr<float>()[0],
inputs[DotIn::rhs_max].data().dptr<float>()[0]);
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
// fused requantize => output is int
out_scale_ = GetQuantizeScale(outputs[DotOut::out].dtype(),
param.min_calib_range.value(),
param.max_calib_range.value()) /
lhs_scale_ / rhs_scale_;
attr.set_output_scales(0, {out_scale_});
} else if (param.enable_float_output) {
out_scale_ = 1.0 / lhs_scale_ / rhs_scale_;
attr.set_output_scales(0, {out_scale_});
}

return attr;
}

DNNLBatchDotFwd::DNNLBatchDotFwd(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
auto shape = inputs[0].shape();
auto ndim = shape.ndim();
auto bigDim = shape[0];
auto lhs_shape = inputs[DotIn::lhs].shape();
auto ndim = lhs_shape.ndim();
auto bigDim = lhs_shape[0];
for (size_t i = 1; i < ndim - 2; ++i) {
bigDim *= shape[i];
bigDim *= lhs_shape[i];
}

auto GetMemoryDesc = [&ndim, &bigDim](const NDArray& tensor, const bool transpose) {
@@ -91,37 +111,106 @@ DNNLBatchDotFwd::DNNLBatchDotFwd(const DotParam& param,
}
};

dnnl::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
dnnl::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
dnnl::memory::desc data_md = GetMemoryDesc(inputs[DotIn::lhs], param.transpose_a);
dnnl::memory::desc weights_md = GetMemoryDesc(inputs[DotIn::rhs], param.transpose_b);
dnnl::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
get_dnnl_type(outputs[0].dtype()),
get_dnnl_type(outputs[DotOut::out].dtype()),
dnnl::memory::format_tag::any);
dnnl::matmul::desc fwd_desc(data_md, weights_md, out_md);
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
if (param.quantized) {
auto attrs = GetQuantizationAttributes(param, inputs, outputs);
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(
fwd_desc, attrs, mxnet::CpuEngine::Get()->get_engine());

} else {
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
}

fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
}

void DNNLBatchDotFwd::Execute(const std::vector<NDArray>& inputs,
void DNNLBatchDotFwd::Execute(const OpContext& ctx,
const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
auto engine = mxnet::CpuEngine::Get()->get_engine();
auto data =
dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[0].data().dptr_));
auto weights =
dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(inputs[1].data().dptr_));
dnnl_output_t out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
auto lhs = inputs[DotIn::lhs];
auto rhs = inputs[DotIn::rhs];
// The created primitive descriptor assumes that both inputs are in the default format
if (lhs.IsDNNLData())
lhs = lhs.Reorder2Default();
if (rhs.IsDNNLData())
rhs = rhs.Reorder2Default();

auto lhs_mem =
dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(lhs.data().dptr_));
auto rhs_mem =
dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(rhs.data().dptr_));
dnnl_output_t out_mem = CreateDNNLMem(
outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);

dnnl_args_map_t args = {
{DNNL_ARG_SRC, data},
{DNNL_ARG_WEIGHTS, weights},
{DNNL_ARG_SRC, lhs_mem},
{DNNL_ARG_WEIGHTS, rhs_mem},
{DNNL_ARG_DST, *out_mem.second},
};

DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
CommitOutput(outputs[0], out_mem);
DNNLStream::Get()->Submit();

if (param.quantized && !param.enable_float_output) {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
float min_output;
float max_output;
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
min_output = param.min_calib_range.value();
max_output = param.max_calib_range.value();
} else {
if (inputs[DotIn::lhs].dtype() == mshadow::kInt8) {
mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
s,
1,
&min_output,
&max_output,
inputs[DotIn::rhs_min].data().dptr<float>(),
inputs[DotIn::rhs_max].data().dptr<float>(),
inputs[DotIn::lhs_min].data().dptr<float>(),
inputs[DotIn::lhs_max].data().dptr<float>());
} else {
mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
s,
1,
&min_output,
&max_output,
inputs[DotIn::rhs_min].data().dptr<float>(),
inputs[DotIn::rhs_max].data().dptr<float>(),
inputs[DotIn::lhs_min].data().dptr<float>(),
inputs[DotIn::lhs_max].data().dptr<float>());
}
}

float* min_output_ptr = outputs[DotOut::out_min].data().dptr<float>();
float* max_output_ptr = outputs[DotOut::out_max].data().dptr<float>();
*min_output_ptr = min_output;
*max_output_ptr = max_output;
}
}

} // namespace op
} // namespace mxnet

namespace std {
template <>
struct hash<mxnet::op::DNNLDotParam> {
size_t operator()(const mxnet::op::DNNLDotParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.transpose_a);
ret = dmlc::HashCombine(ret, val.transpose_b);
ret = dmlc::HashCombine(ret, val.quantized);
return ret;
}
};
} // namespace std
#endif // MXNET_USE_ONEDNN == 1
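One note on the DNNLBatchDotFwd constructor above: it collapses all leading batch dimensions of the lhs into a single bigDim and hands oneDNN a plain 3D matmul of shape (bigDim, M, K) x (bigDim, K, N) -> (bigDim, M, N), with the transpose flags swapping the last two axes. A small self-contained sketch of that shape bookkeeping follows, assuming lhs and rhs carry matching batch dimensions; the helper is illustrative and not part of the commit.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Illustrative only: mirrors how the constructor folds leading batch dims into
// "bigDim" and treats batch_dot as a 3D matmul.
std::vector<long> BatchDotOutShape(const std::vector<long>& lhs,
                                   const std::vector<long>& rhs,
                                   bool transpose_a,
                                   bool transpose_b) {
  assert(lhs.size() >= 3 && lhs.size() == rhs.size());
  long big_dim = 1;
  for (size_t i = 0; i + 2 < lhs.size(); ++i)  // same role as the `bigDim *= lhs_shape[i]` loop
    big_dim *= lhs[i];
  const size_t nd = lhs.size();
  const long m  = transpose_a ? lhs[nd - 1] : lhs[nd - 2];
  const long k  = transpose_a ? lhs[nd - 2] : lhs[nd - 1];
  const long k2 = transpose_b ? rhs[nd - 1] : rhs[nd - 2];
  const long n  = transpose_b ? rhs[nd - 2] : rhs[nd - 1];
  assert(k == k2 && "contracted dimensions must match");
  return {big_dim, m, n};
}

int main() {
  const auto out = BatchDotOutShape({2, 3, 4, 8}, {2, 3, 8, 5}, false, false);
  std::printf("output shape = (%ld, %ld, %ld)\n", out[0], out[1], out[2]);  // (6, 4, 5)
  return 0;
}
```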
1 change: 1 addition & 0 deletions src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -159,6 +159,7 @@ void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs);

/* For batch dot */
template <bool subgraph>
void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
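The template <bool subgraph> parameter added to this declaration lets one compute function serve both the stand-alone batch_dot operator (plain DotParam, quantization always off) and the fused _sg_onednn_batch_dot subgraph operator (DNNLDotParam). Below is a toy sketch of that dispatch with illustrative struct names; unlike the real code, which reads attrs.parsed through nnvm::get, it uses if constexpr so the example stays self-contained.

```cpp
#include <cstdio>

struct PlainDotParam { bool transpose_a; bool transpose_b; };                  // stands in for DotParam
struct FusedDotParam { bool transpose_a; bool transpose_b; bool quantized; };  // stands in for DNNLDotParam

void RunBatchDot(const FusedDotParam& p) {
  std::printf("transpose_a=%d transpose_b=%d quantized=%d\n",
              p.transpose_a, p.transpose_b, p.quantized);
}

template <bool subgraph, typename Attrs>
void Forward(const Attrs& parsed) {
  FusedDotParam full{};
  if constexpr (!subgraph) {
    // Stand-alone batch_dot: translate the plain param, quantization stays off.
    full.transpose_a = parsed.transpose_a;
    full.transpose_b = parsed.transpose_b;
    full.quantized   = false;
  } else {
    // Fused subgraph op: the parsed attrs already carry the full parameter.
    full = parsed;
  }
  RunBatchDot(full);
}

int main() {
  Forward<false>(PlainDotParam{true, false});
  Forward<true>(FusedDotParam{false, true, true});
  return 0;
}
```

In the commit itself, the stand-alone path corresponds to the DNNLBatchDotForward<false> instantiation and the fused _sg_onednn_batch_dot op to the default subgraph = true one.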