Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix sanity
Browse files Browse the repository at this point in the history
  • Loading branch information
bgawrych committed Oct 19, 2021
1 parent 77df20e commit f0b5603
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 28 deletions.
15 changes: 7 additions & 8 deletions src/operator/nn/dnnl/dnnl_batch_dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,22 @@ namespace op {
enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
enum DotOut { out = 0, out_min, out_max };


struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
bool transpose_a;
bool transpose_b;
bool quantized;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
bool enable_float_output; // min float value calculated from calibration dataset
bool enable_float_output; // min float value calculated from calibration dataset
DMLC_DECLARE_PARAMETER(DNNLDotParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(transpose_b)
.describe("If true then transpose the second input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
Expand Down Expand Up @@ -110,13 +109,13 @@ class DNNLBatchDotFwd {

template <bool subgraph = true>
void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
DNNLDotParam dnnl_param;
if (!subgraph) {
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
dnnl_param.transpose_a = param.transpose_a;
dnnl_param.transpose_b = param.transpose_b;
dnnl_param.quantized = false;
Expand Down
4 changes: 2 additions & 2 deletions src/operator/nn/dnnl/dnnl_batch_dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DNNLDotParam& param,
}

dnnl::primitive_attr GetQuantizationAttributes(const DNNLDotParam& param,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
dnnl::primitive_attr attr;
float out_scale_ = 1.f;
float lhs_scale_ = GetQuantizeScale(inputs[DotIn::lhs].dtype(),
Expand Down
26 changes: 13 additions & 13 deletions src/operator/subgraph/dnnl/dnnl_batch_dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ namespace mxnet {
namespace op {

bool DNNLBatchDotShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shapes,
mxnet::ShapeVector* out_shapes) {
mxnet::ShapeVector* in_shapes,
mxnet::ShapeVector* out_shapes) {
const DNNLDotParam& param = nnvm::get<DNNLDotParam>(attrs.parsed);
mxnet::ShapeVector base_in_shapes;
mxnet::ShapeVector base_out_shapes;
Expand Down Expand Up @@ -70,9 +70,9 @@ bool DNNLBatchDotShape(const nnvm::NodeAttrs& attrs,
}

bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_types,
std::vector<int>* out_types) {
const DNNLDotParam& param = nnvm::get<DNNLDotParam>(attrs.parsed);
std::vector<int>* in_types,
std::vector<int>* out_types) {
const DNNLDotParam& param = nnvm::get<DNNLDotParam>(attrs.parsed);
const size_t base_num_inputs = 2;
if (param.quantized) {
CHECK(in_types->at(DotIn::lhs) == mshadow::kInt8 || in_types->at(DotIn::lhs) == mshadow::kUint8)
Expand Down Expand Up @@ -107,10 +107,10 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
}

// Storage-type inference for the DNNL (oneDNN) batch_dot operator.
// Delegates to DNNLStorageType with support_dnnl=true, i.e. the operator is
// always dispatched to the DNNL implementation when the device allows it.
//
// \param attrs          node attributes (unused here beyond forwarding)
// \param dev_mask       device mask of the context the op runs on
// \param dispatch_mode  [out] chosen dispatch mode
// \param in_attrs       [in/out] storage types of the inputs
// \param out_attrs      [in/out] storage types of the outputs
// \return true if the storage types were successfully inferred
inline static bool DNNLBatchDotStorageType(const nnvm::NodeAttrs& attrs,
                                           const int dev_mask,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int>* in_attrs,
                                           std::vector<int>* out_attrs) {
  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}

Expand Down Expand Up @@ -157,10 +157,10 @@ NNVM_REGISTER_OP(_sg_dnnl_batch_dot)
[](const NodeAttrs& attrs) { return QuantizeType::kMust; })
.set_attr<FQuantizedOp>("FQuantizedOp",
[](const NodeAttrs& attrs) {
nnvm::ObjectPtr node = nnvm::Node::Create();
node->attrs.op = Op::Get("_sg_dnnl_batch_dot");
node->attrs.name = "quantized_" + attrs.name;
node->attrs.dict = attrs.dict;
nnvm::ObjectPtr node = nnvm::Node::Create();
node->attrs.op = Op::Get("_sg_dnnl_batch_dot");
node->attrs.name = "quantized_" + attrs.name;
node->attrs.dict = attrs.dict;
node->attrs.dict["quantized"] = "True";

if (node->op()->attr_parser != nullptr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty {
}

// Builds the selector that walks the graph and picks matmul nodes eligible
// for post-quantization fusion. The two flags are member switches of
// SgDNNLMatmulPostQuantizeProperty controlling which fusions are disabled.
SubgraphSelectorPtr CreateSubgraphSelector() const override {
  auto selector =
      std::make_shared<SgDNNLMatmulPostQuantizeSelector>(disable_fuse_all, disable_float_output);
  return selector;
}

Expand Down
6 changes: 3 additions & 3 deletions src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1516,15 +1516,15 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs,
});
}

template<typename ParamType>
template <typename ParamType>
inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
// return false if lhs and rhs both have fully unknown shape
if (!ndim_is_known(lshape) || !ndim_is_known(rshape))
return false;
Expand Down

0 comments on commit f0b5603

Please sign in to comment.