Skip to content

Commit

Permalink
Register fake grad to subgraph and quantized operators (apache#14275)
Browse files Browse the repository at this point in the history
* add fake grad

* Skip inference only subgraph pass when gradient is needed.

* add fake grad to quantizev2

* add TODO

* modify prop_name to property_name

* add test case
  • Loading branch information
xinyu-intel authored and vdantu committed Mar 31, 2019
1 parent 777d678 commit e8a657a
Show file tree
Hide file tree
Showing 15 changed files with 114 additions and 21 deletions.
34 changes: 27 additions & 7 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1506,8 +1506,26 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes) {
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types) {
auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
bool need_grad = false;
for (OpReqType req : grad_req_types) {
if (req != kNullOp) {
need_grad = true;
break;
}
}
if (subgraph_prop->HasAttr("inference_only") &&
subgraph_prop->GetAttr<bool>("inference_only") == true) {
if (need_grad) {
auto full_name = subgraph_prop->HasAttr("prop_name")
? subgraph_prop->GetAttr<std::string>("prop_name")
: prop_name;
LOG(INFO) << "Skip subgraph " << full_name << " as it requires `grad_req=null`.";
return src;
}
}
nnvm::Symbol ret = src.Copy();
nnvm::Graph g;
g.outputs = ret.outputs;
Expand Down Expand Up @@ -1539,7 +1557,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes) {
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types) {
const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape());
nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
Expand All @@ -1559,7 +1578,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
}
}
return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types);
}

// Given input ndarrays, partition the graph using the backend name equal to prop_name.
Expand All @@ -1569,7 +1588,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
std::vector<NDArray> *in_args,
const std::vector<NDArray> &aux_states,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map) {
const std::map<std::string, Context>& ctx_map,
const std::vector<OpReqType>& grad_req_types) {
const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
Expand Down Expand Up @@ -1609,7 +1629,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
in_args_map[arg_names[i]] = in_args->at(i);
}
auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, in_arg_ctxes, aux_state_ctxes);
ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types);
// Reorder in_args into new_in_args according to partitioned symbol input sequence
std::vector<NDArray> new_in_args(in_args->size());
// get new symbol in_arg names
Expand Down Expand Up @@ -1644,7 +1664,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
if (!exec->subgraph_property().empty()) {
symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
aux_state_ctxes);
aux_state_ctxes, grad_req_types);
}
exec->Init(symbol, default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
Expand All @@ -1667,7 +1687,7 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
std::vector<NDArray> tmp_in_args = in_args;
if (!exec->subgraph_property().empty()) {
symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states,
default_ctx, group2ctx);
default_ctx, group2ctx, grad_req_type);
}
exec->Init(symbol, default_ctx, group2ctx,
tmp_in_args, arg_grad_store, grad_req_type, aux_states,
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ by keep zero centered for the quantized value:
.set_attr<mxnet::FInferShape>("FInferShape", DequantizeShape)
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ where
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizeType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantize_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ If min_calib_range isn't presented, the output type will be int8.
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeV2Shape)
.set_attr<nnvm::FInferType>("FInferType", QuantizeV2Type)
.set_attr<FInferStorageType>("FInferStorageType", QuantizeV2StorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeV2Compute)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantized_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ If any input holds int8, then the output will be int8. Otherwise output will be
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"output", "min_output", "max_output"};
})
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.set_attr<std::string>("key_var_num_args", "num_args")
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantized_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ and max thresholds representing the threholds for quantizing the float32 output
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedConvShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedConvType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedConvStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantized_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ NNVM_REGISTER_OP(_contrib_quantized_flatten)
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedFlattenShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFlattenType)
.set_attr<FCompute>("FCompute<cpu>", QuantizedFlattenCompute<cpu>)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "min_data", "max_data"};
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantized_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ and max thresholds representing the threholds for quantizing the float32 output
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FComputeEx>("FComputeEx<cpu>",
QuantizedFullyConnectedForward<int8_t>)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantized_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ the float32 data into int8.
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedPoolingShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedPoolingType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedPoolingStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FNeedRequantize>("FNeedRequantize",
[](const NodeAttrs& attrs) {
const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ inference accuracy.
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
.set_attr<nnvm::FInferType>("FInferType", RequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNRequantizeForward)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,9 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv)
.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNConvOpStorageType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNConvOpForward)
.set_attr<bool>("TIsMKLDNN", true)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty {
}
}
static SubgraphPropertyPtr Create() {
return std::make_shared<SgMKLDNNConvPostQuantizeProperty>();
auto property = std::make_shared<SgMKLDNNConvPostQuantizeProperty>();
property->SetAttr<std::string>("property_name",
"MKLDNN Convolution post-quantization optimization pass");
property->SetAttr<bool>("inference_only", true);
return property;
}
nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
const int subgraph_id = 0) const override {
Expand Down
5 changes: 4 additions & 1 deletion src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ class SgMKLDNNConvProperty : public SubgraphProperty {
}
}
static SubgraphPropertyPtr Create() {
return std::make_shared<SgMKLDNNConvProperty>();
auto property = std::make_shared<SgMKLDNNConvProperty>();
property->SetAttr<std::string>("prop_name", "MKLDNN Convolution optimization pass");
property->SetAttr<bool>("inference_only", true);
return property;
}
nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
const int subgraph_id = 0) const override {
Expand Down
7 changes: 7 additions & 0 deletions src/operator/subgraph/subgraph_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ class SubgraphProperty {
CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
return nnvm::get<T>(*it->second);
}
/*!
* \brief Check if the attr exist.
*/
bool HasAttr(const std::string& name) const {
auto it = attrs_.find(name);
return it != attrs_.end();
}

protected:
std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
Expand Down
53 changes: 41 additions & 12 deletions tests/python/mkl/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,37 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
output.wait_to_read()
return mod.get_outputs()

def check_quantize(sym, data_shape, out_type, check_conv=True):
def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
# save qsym to JSON file
qsym.save('quantized-symbol.json')
# save params
save_dict = {('arg:%s' % k): v.as_in_context(mx.current_context()) for k, v in qarg_params.items()}
save_dict.update({('aux:%s' % k): v.as_in_context(mx.current_context()) for k, v in qaux_params.items()})
mx.nd.save('quantized-0000.params', save_dict)
# load back with SymbolBlock
net = mx.gluon.SymbolBlock.imports('quantized-symbol.json', ['data'], 'quantized-0000.params')
net.collect_params().reset_ctx(ctx = mx.current_context())
net.hybridize()

data = mx.random.uniform(-1.0, 1.0, shape=data_shape)
net(data)

def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=False):
fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
sym_sg = sym.get_backend_symbol("MKLDNN")
label_shape = (data_shape[0], 10)
mod = Module(symbol=sym)
mod.bind(for_training=False,
data_shapes=[('data', data_shape)],
label_shapes=[('softmax_label', label_shape)])
if gluon_forward == True:
sym = fc
sym_sg = fc.get_backend_symbol("MKLDNN")
mod = Module(symbol=sym, label_names=[])
mod.bind(for_training=False,
data_shapes=[('data', data_shape)])
else:
sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
sym_sg = sym.get_backend_symbol("MKLDNN")
label_shape = (data_shape[0], 10)
mod = Module(symbol=sym)
mod.bind(for_training=False,
data_shapes=[('data', data_shape)],
label_shapes=[('softmax_label', label_shape)])
mod.init_params(mx.init.Normal(0.5))
arg_params, aux_params = mod.get_params()

Expand Down Expand Up @@ -107,10 +129,13 @@ def check_quantize(sym, data_shape, out_type, check_conv=True):
qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
if check_conv:
check_qsym_calibrated(qsym, out_type)
quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
for i in range(len(ref_out)):
assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
if gluon_forward == True:
check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape)
else:
check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
for i in range(len(ref_out)):
assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)


@with_seed()
Expand All @@ -137,6 +162,7 @@ def check_fusion(sym, data_shape, attrs_op):
# fp32 to int8
for out_type in ('uint8', 'int8', 'auto'):
check_quantize(sym, data_shape, out_type)
check_quantize(sym, data_shape, out_type, gluon_forward=True)

def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)):
for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs):
Expand Down Expand Up @@ -478,10 +504,13 @@ def test_pos_single_concat():
for out_type in ('uint8', 'int8', 'auto'):
net = single_concat(data_shape, 2, 1)
check_quantize(net, data_shape, out_type, False)
check_quantize(net, data_shape, out_type, False, True)
net = single_concat(data_shape, 4, 2)
check_quantize(net, data_shape, out_type, False)
check_quantize(net, data_shape, out_type, False, True)
net = single_concat(data_shape, 4, 3)
check_quantize(net, data_shape, out_type, False)
check_quantize(net, data_shape, out_type, False, True)

@with_seed()
def test_neg_conv_bn():
Expand Down

0 comments on commit e8a657a

Please sign in to comment.