Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[DO NOT REVIEW] Subgraph API (#12104)
Browse files Browse the repository at this point in the history
* Initial commit

* Add unit tests

* Fix lint

* Fix lint

* Clean up

* Add graph partitiong to Bind

* Add property name to graph partitioning c api

* Fix unit test gpu context

* Address cr

* Move subgraph to attrs.subgraphs and fix the example

* Fix lint

* Add var version unit test

* Address cr

* Enable unit test that was flaky
  • Loading branch information
reminisce authored Aug 14, 2018
1 parent 6acf6cc commit 4c1933e
Show file tree
Hide file tree
Showing 21 changed files with 720 additions and 197 deletions.
31 changes: 24 additions & 7 deletions example/subgraph_op/imagenet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Score a model on a dataset')
parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
parser.add_argument('--model', type=str, required=True,
choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--label-name', type=str, default='softmax_label')
Expand All @@ -107,6 +108,8 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
help='shuffling seed, see'
' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
' for more details')
parser.add_argument('--subgraph-backend', type=str, default='default', help='subgraph backend name.')
parser.add_argument('--ctx', type=str, default='cpu')

args = parser.parse_args()

Expand All @@ -133,6 +136,15 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
logger.info('Dataset for inference: %s' % dataset)

subgraph_backend = args.subgraph_backend

if args.ctx == 'cpu':
ctx = mx.cpu()
elif args.ctx == 'gpu':
ctx = mx.gpu(0)
else:
raise ValueError('unknown ctx option, only cpu and gpu are supported')

# creating data iterator
data = mx.io.ImageRecordIter(path_imgrec=dataset,
label_width=1,
Expand All @@ -151,16 +163,21 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
prefix, epoch = download_model(model_name=args.model, logger=logger)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
out = SymbolHandle()
check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
ctypes.byref(out)))
psym = Symbol(out)

if subgraph_backend is not None:
os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
if subgraph_backend == 'default':
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
# make sure that fp32 inference works on the same images as calibrated quantized model
logger.info('Skipping the first %d batches' % args.num_skipped_batches)
data = advance_data_iter(data, args.num_skipped_batches)

num_inference_images = args.num_inference_batches * batch_size
logger.info('Running model %s for inference' % args.model)
score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
score(sym, arg_params, aux_params, data, [ctx], label_name,
max_num_examples=num_inference_images, logger=logger)

if subgraph_backend is not None:
del os.environ['MXNET_SUBGRAPH_BACKEND']
if subgraph_backend == 'default':
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
5 changes: 0 additions & 5 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1479,11 +1479,6 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
const float* high_quantiles,
SymbolHandle* ret_sym_handle);

MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
Expand Down
66 changes: 66 additions & 0 deletions include/mxnet/c_api_test.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file c_api_test.h
* \brief C API of mxnet for ease of testing backend in Python
*/
#ifndef MXNET_C_API_TEST_H_
#define MXNET_C_API_TEST_H_

/*! \brief Inhibit C++ name-mangling for MXNet functions. */
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus

#include <mxnet/c_api.h>

/*!
* \brief This API partitions a graph only by the operator names
* provided by users. This will attach a DefaultSubgraphProperty
* to the input graph for partitioning. This function should be
* used only for the testing purpose.
*/
MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle);

/*!
* \brief Given a subgraph property name, use the provided op names
* as the op_names attribute for that subgraph property, instead of
* the predefined one. This is only for the purpose of testing.
*/
MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
const mx_uint num_ops,
const char** op_names);

/*!
* \brief Given a subgraph property name, delete the op name set
* in the SubgraphPropertyOpNameSet.
*/
MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);

#ifdef __cplusplus
}
#endif // __cplusplus

#endif // MXNET_C_API_TEST_H_
4 changes: 2 additions & 2 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Engine;
namespace engine {
/*! \brief base class of engine variables.*/
struct Var {
virtual uint32_t version() {
virtual size_t version() {
return version_;
}
virtual ~Var() = default;
Expand All @@ -58,7 +58,7 @@ struct Var {
* \brief version number of the var. Every time the object it is associated with
* is modified, the version number is incremented by 1.
*/
uint32_t version_{0};
size_t version_{0};
}; // struct Var

/*! \brief Internal representation of operator. */
Expand Down
25 changes: 0 additions & 25 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/default_subgraph_op.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -697,27 +696,3 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
*ret_qsym_handle = s;
API_END_HANDLE_ERROR(delete s);
}

int MXPartitionGraph(SymbolHandle sym_handle,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle) {
nnvm::Symbol* s = new nnvm::Symbol();
API_BEGIN();
std::unordered_set<std::string> op_name_set;
for (size_t i = 0; i < num_ops; ++i) {
op_name_set.emplace(op_names[i]);
}
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
if (!op_name_set.empty()) {
mxnet::op::SubgraphPropertyPtr property
= std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
}
g = ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
73 changes: 73 additions & 0 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file c_api_test.cc
* \brief C API of mxnet for the ease of testing backend in Python
*/
#include <mxnet/c_api_test.h>
#include <nnvm/pass.h>
#include "./c_api_common.h"
#include "../operator/subgraph/default_subgraph_property.h"

int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle) {
nnvm::Symbol* s = new nnvm::Symbol();
API_BEGIN();
std::unordered_set<std::string> op_name_set;
for (size_t i = 0; i < num_ops; ++i) {
op_name_set.emplace(op_names[i]);
}
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
*s = sym->Copy();
nnvm::Graph g;
g.outputs = s->outputs;
if (!op_name_set.empty()) {
mxnet::op::SubgraphPropertyPtr property
= mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
}
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}

int MXSetSubgraphPropertyOpNames(const char* prop_name,
const mx_uint num_ops,
const char** op_names) {
API_BEGIN();
std::unordered_set<std::string> op_name_set;
for (size_t i = 0; i < num_ops; ++i) {
op_name_set.emplace(op_names[i]);
}
(*mxnet::op::SubgraphPropertyOpNameSet::Get())[prop_name] = op_name_set;
API_END();
}

int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
API_BEGIN();
mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
API_END();
}
4 changes: 0 additions & 4 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ class NaiveEngine final : public Engine {
// new variables
VarHandle NewVariable() override {
return NaiveVar::New();
#if 0
size_t v = ++counter_;
return reinterpret_cast<VarHandle>(v);
#endif
}

OprHandle NewOperator(AsyncFn fn,
Expand Down
2 changes: 1 addition & 1 deletion src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ inline bool ThreadedVar::ready_to_read() {
return this->is_ready_to_read();
}

inline uint32_t ThreadedVar::version() {
inline size_t ThreadedVar::version() {
std::lock_guard<std::mutex> lock{mutex_};
return this->version_;
}
Expand Down
2 changes: 1 addition & 1 deletion src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class ThreadedVar final
inline void SetToDelete();
/*! \return whether this variable is ready to read. */
inline bool ready_to_read();
inline uint32_t version() override;
inline size_t version() override;
/*!
* \brief Cast a Var pointer to ThreadedVar pointer
* \param ptr pointer from base.
Expand Down
Loading

0 comments on commit 4c1933e

Please sign in to comment.