This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Graph partitioner and subgraph op #11251

Merged: 37 commits, Jun 21, 2018

Commits
4b2f049
subgraph op
reminisce May 18, 2018
f343b0c
Add treating whole graph as a subgraph
reminisce May 18, 2018
d13e489
Add forward executor in subgraph op forward function
reminisce May 20, 2018
0ea2763
Add unit test
reminisce May 20, 2018
89f321a
Add finding subgraphs
reminisce May 25, 2018
1e17e87
Add find input/output entries of a subgraph
reminisce May 25, 2018
a337e09
Add get input/output entries
reminisce May 25, 2018
d6bf59c
Add debug info
reminisce May 26, 2018
7a39d7f
Handle corner case when a subgraph node is the last node of the graph
reminisce May 26, 2018
10b67b6
Add var version for ndarray
reminisce May 30, 2018
051c184
Fix unit test
reminisce May 31, 2018
c585398
generalize graph partitioning. (#11)
zheng-da Jun 1, 2018
382c792
change the interface (#13)
zheng-da Jun 2, 2018
59652ef
Reorder input/output entries to follow original topological sort order
reminisce Jun 2, 2018
22d2c9d
Make subgraph op node name unique
reminisce Jun 2, 2018
c81da86
Fix sorting entries
reminisce Jun 4, 2018
31cc8fc
Allow for adjacent subgraphs
reminisce Jun 5, 2018
1293828
Enable partitioning algo to eliminate cycles
reminisce Jun 6, 2018
987f145
fix shape/dtype/storage inference. (#15)
zheng-da Jun 6, 2018
c95e912
Remove graph executor from subgraph op forward
reminisce Jun 6, 2018
7e96475
Naive engine end to end run through
reminisce Jun 7, 2018
8e3d4ed
execute subgraph operators synchronously (#16)
zheng-da Jun 8, 2018
a642047
Fix graph partitioning bug and add var version for naive engine
reminisce Jun 8, 2018
63dcbcf
Refactor
reminisce Jun 11, 2018
575ef94
Toggle engine var version print out
reminisce Jun 11, 2018
41a97c8
Add virtual destructor to engine var
reminisce Jun 11, 2018
0227c27
Fix gpu build and add example
reminisce Jun 11, 2018
a9cf101
Fix FMutateInputs for subgraph op
reminisce Jun 11, 2018
d71467b
Add unit test for FMutateInputs attr of subgraph
reminisce Jun 11, 2018
cf7c4fe
fix DefaultSubgraphOpMutableInputs (#17)
TaoLv Jun 17, 2018
50d0e60
Add post processing algorithm for filtering out nodes
reminisce Jun 18, 2018
dcf2750
Refactor post processing
reminisce Jun 19, 2018
6e29456
Change file names
reminisce Jun 19, 2018
c6d6c3b
Address cr
reminisce Jun 20, 2018
d24900c
Fix lint
reminisce Jun 20, 2018
5660a85
Fix storage type vector
reminisce Jun 21, 2018
5088603
Temporarily disable test_gluon_trainer.test_trainer_reset_kv
reminisce Jun 21, 2018
1 change: 1 addition & 0 deletions example/subgraph_op/common
166 changes: 166 additions & 0 deletions example/subgraph_op/imagenet_inference.py
@@ -0,0 +1,166 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Member commented:

Could you write a simple README for this file? How to run this test? How can users play with this test?

Contributor Author (@reminisce) replied on Jun 14, 2018:

You can use the following command to try it. Please remember to change the ctx passed to the function score into mx.cpu() if you are running on a CPU. It's just a preliminary end-to-end run-through for proof of concept. More changes are coming.

python imagenet_inference.py --model=imagenet1k-resnet-152 --dataset=data --num-inference-batches=10
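Concretely, the CPU change mentioned above applies to the final score(...) call at the bottom of this file; a minimal sketch of that call with only the device list swapped, everything else as in the script:

    # CPU-only variant of the last call in imagenet_inference.py (sketch)
    score(psym, arg_params, aux_params, data, [mx.cpu()], label_name,
          max_num_examples=num_inference_images, logger=logger)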

import argparse
import logging
import os
import time
import mxnet as mx
from common import modelzoo
from mxnet import nd
from mxnet.contrib.quantization import *
from mxnet.base import _LIB


def download_dataset(dataset_url, dataset_dir, logger=None):
    if logger is not None:
        logger.info('Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir))
    mx.test_utils.download(dataset_url, dataset_dir)


def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
    return modelzoo.download_model(model_name, model_path)


def advance_data_iter(data_iter, n):
    assert n >= 0
    if n == 0:
        return data_iter
    has_next_batch = True
    while has_next_batch:
        try:
            data_iter.next()
            n -= 1
            if n == 0:
                return data_iter
        except StopIteration:
            has_next_batch = False


def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
    metrics = [mx.metric.create('acc'),
               mx.metric.create('top_k_accuracy', top_k=5)]
    if not isinstance(metrics, list):
        metrics = [metrics, ]
    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
    mod.bind(for_training=False,
             data_shapes=data.provide_data,
             label_shapes=data.provide_label)
    mod.set_params(arg_params, aux_params)

    tic = time.time()
    num = 0
    for batch in data:
        mod.forward(batch, is_train=False)
        for m in metrics:
            mod.update_metric(m, batch.label)
        num += batch_size
        if max_num_examples is not None and num >= max_num_examples:
            break

    speed = num / (time.time() - tic)

    if logger is not None:
        logger.info('Finished inference with %d images' % num)
        logger.info('Finished with %f images per second', speed)
        for m in metrics:
            logger.info(m.get())


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Score a model on a dataset')
    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--label-name', type=str, default='softmax_label')
    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
    parser.add_argument('--image-shape', type=str, default='3,224,224')
    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
    parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference')
    parser.add_argument('--num-inference-batches', type=int, required=True, help='number of images used for inference')
    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
                        help='shuffle the calibration dataset')
    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
                        help='shuffling chunk seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')
    parser.add_argument('--shuffle-seed', type=int, default=48564309,
                        help='shuffling seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')

    args = parser.parse_args()

    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)
    data_nthreads = args.data_nthreads
    batch_size = args.batch_size
    logger.info('batch size = %d for inference' % batch_size)

    rgb_mean = args.rgb_mean
    logger.info('rgb_mean = %s' % rgb_mean)
    rgb_mean = [float(i) for i in rgb_mean.split(',')]
    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}

    label_name = args.label_name
    logger.info('label_name = %s' % label_name)

    image_shape = args.image_shape
    data_shape = tuple([int(i) for i in image_shape.split(',')])
    logger.info('Input data shape = %s' % str(data_shape))

    dataset = args.dataset
    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
    logger.info('Dataset for inference: %s' % dataset)

    # creating data iterator
    data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                 label_width=1,
                                 preprocess_threads=data_nthreads,
                                 batch_size=batch_size,
                                 data_shape=data_shape,
                                 label_name=label_name,
                                 rand_crop=False,
                                 rand_mirror=False,
                                 shuffle=True,
                                 shuffle_chunk_seed=3982304,
                                 seed=48564309,
                                 **mean_args)

    # download model
    prefix, epoch = download_model(model_name=args.model, logger=logger)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
    out = SymbolHandle()
    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
                                     ctypes.byref(out)))
    psym = Symbol(out)

Member commented on the MXPartitionGraph call above:

How about wrapping this function into Symbol class?

Contributor Author replied:

We will hide this from users. The API design is still under discussion.
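As an illustration of the wrapping idea discussed above, the ctypes plumbing could be folded into a small helper. The function below is only a hypothetical sketch, not an API added by this PR, and assumes the C API keeps the signature shown in this change:

    # Hypothetical helper (illustration only, not part of this PR's public API).
    import ctypes
    from mxnet.base import _LIB, SymbolHandle, check_call, c_str_array, mx_uint
    from mxnet.symbol import Symbol

    def partition_graph(sym, op_names):
        """Return a copy of `sym` with ops named in `op_names` grouped into subgraph nodes."""
        out = SymbolHandle()
        check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
                                         c_str_array(op_names), ctypes.byref(out)))
        return Symbol(out)

    # Usage: psym = partition_graph(sym, ['BatchNorm', 'Convolution', 'Pooling', 'Activation'])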

    # make sure that fp32 inference works on the same images as calibrated quantized model
    logger.info('Skipping the first %d batches' % args.num_skipped_batches)
    data = advance_data_iter(data, args.num_skipped_batches)

    num_inference_images = args.num_inference_batches * batch_size
    logger.info('Running model %s for inference' % args.model)
    score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
          max_num_examples=num_inference_images, logger=logger)
5 changes: 5 additions & 0 deletions include/mxnet/c_api.h
@@ -1457,6 +1457,11 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
const float* high_quantiles,
SymbolHandle* ret_sym_handle);

MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
22 changes: 20 additions & 2 deletions include/mxnet/engine.h
@@ -41,8 +41,26 @@ class Engine;

/*! \brief namespace of engine internal types. */
namespace engine {
/*! \brief Internal representation of variable. */
struct Var;
/*! \brief base class of engine variables.*/
struct Var {
virtual uint32_t version() {
return version_;
}
virtual ~Var() = default;
/*!
* \brief cast variable to derived type T
* \tparam T the type we want to cast into.
* \return A casted variable.
*/
template <typename T>
inline T* Cast();
/*!
* \brief version number of the var. Every time the object it is associated with
* is modified, the version number is incremented by 1.
*/
uint32_t version_{0};
}; // struct Var

/*! \brief Internal representation of operator. */
struct Opr;
/*! \brief Variable pointer type, usually hold by user used to specify dependencies. */
4 changes: 4 additions & 0 deletions include/mxnet/ndarray.h
@@ -338,6 +338,10 @@ class NDArray {
inline size_t byte_offset() const {
return byte_offset_;
}
/*! \brief return var version of the NDArray*/
inline uint32_t version() const {
return var()->version();
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
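The version counter introduced above is exposed only on the C++ NDArray in this change. Purely as a conceptual sketch, assuming a hypothetical Python binding nd_version() for NDArray::version(), the intended semantics are that every operation mutating the array bumps the version of its engine variable by one:

    # Conceptual sketch: `nd_version` is a hypothetical binding for the C++
    # NDArray::version() added in this PR; no such Python API exists here.
    import mxnet as mx
    a = mx.nd.zeros((2, 2))
    v0 = nd_version(a)               # version of the engine var backing `a`
    a += 1                           # one in-place write to `a` ...
    assert nd_version(a) == v0 + 1   # ... increments its version by 1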
7 changes: 6 additions & 1 deletion include/mxnet/op_attr_types.h
@@ -98,7 +98,12 @@ enum class ExecType {
* In current implementation, copy operator is specially handled by executor.
* This flag is used for special case treatment and future extension of different copy ops.
*/
kCrossDeviceCopy
kCrossDeviceCopy,
/*!
* A subgraph execution should happen in the main thread, instead of
* in the execution engine.
*/
kSubgraphExec,

Member commented on the comment above:

operators inside a subgraph are still executed in the execution engine, right?

Contributor Author replied:

Yes.
};

/*! \brief the dispatch mode of the operator */
25 changes: 25 additions & 0 deletions src/c_api/c_api_symbolic.cc
@@ -31,6 +31,7 @@
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/default_subgraph_op.h"

namespace mxnet {
namespace op {
@@ -625,3 +626,27 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
*ret_qsym_handle = s;
API_END_HANDLE_ERROR(delete s);
}

int MXPartitionGraph(SymbolHandle sym_handle,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle) {
nnvm::Symbol* s = new nnvm::Symbol();
API_BEGIN();
std::unordered_set<std::string> op_name_set;
for (size_t i = 0; i < num_ops; ++i) {
op_name_set.emplace(op_names[i]);
}
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
if (!op_name_set.empty()) {
mxnet::op::SubgraphPropertyPtr property
= std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
Member commented:

Is it possible for mxnet to enable several acceleration backends at the same time? So then there will be several different subgraph_property for different backends respectively?

Contributor Author replied:

Not exactly. Your implementation of SubgraphProperty determines which subgraph op node is used and, consequently, which Forward/Backward implementation is adopted. If you want to support multiple backends at the same time, you just need to configure that in the implementation of Forward/Backward. The question is, how do you determine when to use which backend?

}
g = ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
10 changes: 10 additions & 0 deletions src/engine/engine_impl.h
@@ -33,8 +33,12 @@
namespace mxnet {
namespace engine {

#if 0
Member commented on the #if 0 above:

for debug?

/*! \brief base class of engine variables, used for type checking */
struct Var {
virtual uint32_t version() {
return version_;
}
#if ENGINE_DEBUG
virtual ~Var() = default;
#endif // ENGINE_DEBUG
@@ -45,7 +49,13 @@ struct Var {
*/
template <typename T>
inline T* Cast();
/*!
* \brief version number of the var. Every time the object it is associated with
* is modified, the version number is incremented by 1.
*/
uint32_t version_{0};
}; // struct Var
#endif

/*! \brief base class of engine operators, used for type checking */
struct Opr {
31 changes: 29 additions & 2 deletions src/engine/naive_engine.cc
@@ -28,10 +28,24 @@
#include "./engine_impl.h"
#include "../profiler/profiler.h"
#include "./openmp.h"
#include "../common/object_pool.h"

namespace mxnet {
namespace engine {

/*!
* \brief var used in Naive Engine for tracking the version
* of the objects it is associated with.
*/
class NaiveVar final
: public Var, public common::ObjectPoolAllocatable<NaiveVar> {
public:
inline static NaiveVar* CastFromBase(Var* ptr) {
return ptr->Cast<NaiveVar>();
}
}; // class NaiveVar


// implement naive engine
class NaiveEngine final : public Engine {
public:
@@ -71,8 +85,11 @@ class NaiveEngine final : public Engine {

// new variables
VarHandle NewVariable() override {
return NaiveVar::New();
#if 0
Member commented on the #if 0 block:

for debug?

Contributor Author replied:

yes.

size_t v = ++counter_;
return reinterpret_cast<VarHandle>(v);
#endif
}

OprHandle NewOperator(AsyncFn fn,
@@ -165,14 +182,26 @@ class NaiveEngine final : public Engine {
}
CHECK(this->req_completed_)
<< "NaiveEngine only support synchronize Push so far";
// increment var version
for (auto var : mutable_vars) {
++var->version_;
}
if (profiling) {
opr->opr_profile->stop();
}
}

void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
NaiveVar* naive_var = NaiveVar::CastFromBase(var);
this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
delete_fn(ctx);
NaiveVar::Delete(naive_var);
on_complete();
}, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
#if 0
this->PushSync(delete_fn, exec_ctx, {}, {var},
FnProperty::kNormal, 0, "DeleteVariable");
#endif
}

void WaitForVar(VarHandle var) override {
@@ -192,8 +221,6 @@
}
// whether action is completed
bool req_completed_;
// counter
std::atomic<size_t> counter_{0};
/*! \brief whether it is during shutdown phase*/
std::atomic<bool> shutdown_phase_{false};
// CPU stream