
change to use mxnet::Tuple and update tests
ciyongch committed Mar 2, 2019
1 parent 73f95dc commit b7d8324
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -109,7 +109,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   NDArray data = in_data[fullc::kData];
   NDArray weight = in_data[fullc::kWeight];
   NDArray output = out_data[fullc::kOut];
-  const TShape &ishape = data.shape();
+  const mxnet::TShape &ishape = data.shape();
   if (mkldnn_param.quantized && ishape.ndim() != 2) {
     CHECK(default_param.flatten)
         << "QuantizedFullyConnected only supports flatten=true when ishape.ndim() != 2 for now.";
@@ -265,12 +265,12 @@ static inline void FillBaseInputOutputInfo(const FullyConnectedParam &param,
 }
 
 static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
-                                 std::vector<TShape> *in_shapes,
-                                 std::vector<TShape> *out_shapes) {
+                                 mxnet::ShapeVector *in_shapes,
+                                 mxnet::ShapeVector *out_shapes) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
   if (full_param.mkldnn_param.quantized) {
-    std::vector<TShape> base_in_shapes;
-    std::vector<TShape> base_out_shapes;
+    mxnet::ShapeVector base_in_shapes;
+    mxnet::ShapeVector base_out_shapes;
     FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes,
                             in_shapes, out_shapes);
     bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
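The hunk above is truncated here, but the visible pattern is common to MXNet's quantized subgraph operators: shape inference is delegated to the base FP32 signature, then the results are copied back to the quantized signature, whose extra min/max inputs and outputs are scalars. The sketch below illustrates that pattern only; the input layout and helper logic are assumptions, not code from the commit.

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<std::int64_t>;  // stand-in for mxnet::TShape
    using ShapeVector = std::vector<Shape>;   // mirrors mxnet::ShapeVector

    // Assumed layout: inputs 0..2 are data/weight/bias; later slots hold the
    // scalar min/max ranges that quantization appends.
    bool InferShapeViaBaseOp(const ShapeVector &in_shapes, ShapeVector *out_shapes) {
      // 1. Strip the quantization extras to recover the base FC inputs.
      ShapeVector base_in(in_shapes.begin(), in_shapes.begin() + 3);
      // 2. Infer the base output shape (stand-in for DefaultSubgraphOpShape):
      //    (batch, in_dim) x (num_hidden, in_dim) -> (batch, num_hidden).
      Shape base_out = {base_in[0][0], base_in[1][0]};
      // 3. Copy back, appending the scalar min/max output ranges.
      *out_shapes = {base_out, Shape{1}, Shape{1}};
      return true;
    }

    int main() {
      ShapeVector ins = {{32, 512}, {10, 512}, {10}, {1}, {1}};
      ShapeVector outs;
      InferShapeViaBaseOp(ins, &outs);  // outs: {32, 10}, {1}, {1}
      return 0;
    }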
@@ -368,7 +368,7 @@ static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
 
 static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs,
                                         Context ctx,
-                                        const std::vector<TShape> &in_shapes,
+                                        const mxnet::ShapeVector &in_shapes,
                                         const std::vector<int> &in_types) {
   return OpStatePtr::Create<SgMKLDNNFCOp>(attrs);
 }
@@ -414,7 +414,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
 .set_attr_parser(SgMKLDNNFCParamParser)
 .set_attr<nnvm::FListInputNames>("FListInputNames", SgMKLDNNFCListInputNames)
 .set_attr<nnvm::FListOutputNames>("FListOutputNames", SgMKLDNNFCListOutputNames)
-.set_attr<nnvm::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
 .set_attr<nnvm::FInferType>("FInferType", SgMKLDNNFCInferType)
 .set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNFCStorageType)
 .set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNFCState)
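The last C++ hunk re-keys the operator's shape-inference attribute from nnvm::FInferShape to mxnet::FInferShape, matching the new ShapeVector-based signature of SgMKLDNNFCInferShape. MXNet headers of that era define the new alias via nnvm::FInferNodeEntryAttr<mxnet::TShape>; the sketch below expands it for readability, with stand-in types that are assumptions rather than quotes.

    #include <functional>
    #include <vector>

    struct NodeAttrs {};  // stand-in for nnvm::NodeAttrs
    struct TShape {};     // stand-in for mxnet::TShape
    using ShapeVector = std::vector<TShape>;

    // Assumed expansion of mxnet::FInferShape after the migration: a callable
    // over mxnet::ShapeVector rather than std::vector<nnvm::TShape>.
    using FInferShape = std::function<bool(const NodeAttrs &attrs,
                                           ShapeVector *in_shapes,
                                           ShapeVector *out_shapes)>;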
4 changes: 2 additions & 2 deletions tests/python/mkl/test_subgraph.py
@@ -62,7 +62,7 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
         if k.find('_quantize') != -1:
             assert v['out_type'] == out_type
         if k.find(quantized_op_name) != -1:
-            if name == 'fc' and 'fuse_dequantize' in v:
+            if name == 'fc' and 'enable_float_output' in v:
                 continue
             assert 'min_calib_range' in v
             assert 'max_calib_range' in v
@@ -155,7 +155,7 @@ def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True
     for k, v in sym_sg.attr_dict().items():
         if k.find(op_name) != -1:
             for attr_op in attrs_op:
-                assert v[attr_op] == 'true'
+                assert v[attr_op] in ['true', 'True']
 
     arg_shapes, _, aux_shapes = sym.infer_shape()
     arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes]
