-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Graph partitioner and subgraph op #11251
Changes from all commits
4b2f049
f343b0c
d13e489
0ea2763
89f321a
1e17e87
a337e09
d6bf59c
7a39d7f
10b67b6
051c184
c585398
382c792
59652ef
22d2c9d
c81da86
31cc8fc
1293828
987f145
c95e912
7e96475
8e3d4ed
a642047
63dcbcf
575ef94
41a97c8
0227c27
a9cf101
d71467b
cf7c4fe
50d0e60
dcf2750
6e29456
c6d6c3b
d24900c
5660a85
5088603
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../image-classification/common |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import argparse | ||
import logging | ||
import os | ||
import time | ||
import mxnet as mx | ||
from common import modelzoo | ||
from mxnet import nd | ||
from mxnet.contrib.quantization import * | ||
from mxnet.base import _LIB | ||
|
||
|
||
def download_dataset(dataset_url, dataset_dir, logger=None):
    """Fetch the inference dataset with ``mx.test_utils.download``.

    Parameters
    ----------
    dataset_url : str
        URL the dataset is downloaded from.
    dataset_dir : str
        Destination handed to ``mx.test_utils.download`` as its second
        positional argument.
        NOTE(review): the script passes the ``--dataset`` value here and later
        uses the same value as ``path_imgrec``, so this looks like a file path
        rather than a directory -- confirm against ``mx.test_utils.download``.
    logger : logging.Logger, optional
        When given, the download is announced at INFO level first.
    """
    if logger is not None:
        message = 'Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir)
        logger.info(message)
    mx.test_utils.download(dataset_url, dataset_dir)
|
||
|
||
def download_model(model_name, logger=None):
    """Download a pre-trained model into the ``model`` directory next to this script.

    Parameters
    ----------
    model_name : str
        Name of the model, forwarded to ``modelzoo.download_model``.
    logger : logging.Logger, optional
        When given, the download is announced at INFO level first.

    Returns
    -------
    Whatever ``modelzoo.download_model`` returns; the script unpacks it as
    ``(prefix, epoch)``.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
    # Use the function's own argument rather than the script-global ``args`` so
    # the helper also works when it is called from outside the __main__ block.
    return modelzoo.download_model(model_name, model_path)
|
||
|
||
def advance_data_iter(data_iter, n):
    """Skip up to ``n`` batches of ``data_iter`` and return the iterator.

    Parameters
    ----------
    data_iter : object exposing a ``next()`` method
        Iterator to advance; it is consumed in place.
    n : int
        Number of batches to skip. Must be non-negative.

    Returns
    -------
    The same ``data_iter`` object. If the iterator runs out before ``n``
    batches were skipped, the exhausted iterator is returned instead of
    ``None``, so callers that rebind their variable to the result always get
    an iterator back.
    """
    assert n >= 0
    for _ in range(n):
        try:
            data_iter.next()
        except StopIteration:
            # Fewer than n batches were available; stop skipping.
            break
    return data_iter
|
||
|
||
def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
    """Run inference over ``data`` and report accuracy and throughput.

    Parameters
    ----------
    sym : Symbol
        Network to evaluate.
    arg_params, aux_params : dict
        Parameters passed to ``Module.set_params``.
    data : data iterator
        Source of batches; must expose ``provide_data``, ``provide_label``
        and ``batch_size``.
    devs : context or list of contexts
        Passed to ``mx.mod.Module`` as ``context``.
    label_name : str
        Name of the label input of ``sym``.
    max_num_examples : int or None
        Stop once at least this many examples were processed; ``None`` means
        the whole iterator is consumed.
    logger : logging.Logger, optional
        When given, the example count, the speed and every metric are logged
        at INFO level.
    """
    metrics = [mx.metric.create('acc'),
               mx.metric.create('top_k_accuracy', top_k=5)]
    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
    mod.bind(for_training=False,
             data_shapes=data.provide_data,
             label_shapes=data.provide_label)
    mod.set_params(arg_params, aux_params)

    # Take the batch size from the iterator instead of relying on a
    # module-level ``batch_size`` that only exists when run as a script.
    batch_size = data.batch_size

    tic = time.time()
    num = 0
    for batch in data:
        mod.forward(batch, is_train=False)
        for m in metrics:
            mod.update_metric(m, batch.label)
        num += batch_size
        if max_num_examples is not None and num >= max_num_examples:
            break

    speed = num / (time.time() - tic)

    if logger is not None:
        logger.info('Finished inference with %d images' % num)
        logger.info('Finished with %f images per second', speed)
        for m in metrics:
            logger.info(m.get())
|
||
|
||
if __name__ == '__main__':
    # Command-line entry point: download the dataset and model, partition the
    # model graph with MXPartitionGraph and score the partitioned symbol.
    parser = argparse.ArgumentParser(description='Score a model on a dataset')
    # --model is forwarded straight to the model downloader, so it has to be given.
    parser.add_argument('--model', type=str, required=True,
                        choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--label-name', type=str, default='softmax_label')
    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
    parser.add_argument('--image-shape', type=str, default='3,224,224')
    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
    parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference')
    parser.add_argument('--num-inference-batches', type=int, required=True,
                        help='number of batches used for inference')
    # NOTE(review): store_true combined with default=True means this flag can
    # never be switched off from the command line.
    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
                        help='shuffle the calibration dataset')
    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
                        help='shuffling chunk seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')
    parser.add_argument('--shuffle-seed', type=int, default=48564309,
                        help='shuffling seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')

    args = parser.parse_args()

    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)
    data_nthreads = args.data_nthreads
    batch_size = args.batch_size
    logger.info('batch size = %d for inference' % batch_size)

    rgb_mean = args.rgb_mean
    logger.info('rgb_mean = %s' % rgb_mean)
    rgb_mean = [float(i) for i in rgb_mean.split(',')]
    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}

    label_name = args.label_name
    logger.info('label_name = %s' % label_name)

    image_shape = args.image_shape
    data_shape = tuple([int(i) for i in image_shape.split(',')])
    logger.info('Input data shape = %s' % str(data_shape))

    dataset = args.dataset
    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
    logger.info('Dataset for inference: %s' % dataset)

    # creating data iterator
    # The shuffle settings come from the parsed arguments; their defaults are
    # the values that used to be hard-coded here.
    data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                 label_width=1,
                                 preprocess_threads=data_nthreads,
                                 batch_size=batch_size,
                                 data_shape=data_shape,
                                 label_name=label_name,
                                 rand_crop=False,
                                 rand_mirror=False,
                                 shuffle=args.shuffle_dataset,
                                 shuffle_chunk_seed=args.shuffle_chunk_seed,
                                 seed=args.shuffle_seed,
                                 **mean_args)

    # download model
    prefix, epoch = download_model(model_name=args.model, logger=logger)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
    out = SymbolHandle()
    # Review note: a reviewer asked "How about wrapping this function into
    # Symbol class?"; the author replied "We will hide this from users. The API
    # design is still under discussion."
    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
                                     ctypes.byref(out)))
    psym = Symbol(out)

    # make sure that fp32 inference works on the same images as calibrated quantized model
    logger.info('Skipping the first %d batches' % args.num_skipped_batches)
    data = advance_data_iter(data, args.num_skipped_batches)

    num_inference_images = args.num_inference_batches * batch_size
    logger.info('Running model %s for inference' % args.model)
    # The context is fixed to GPU 0; change it to mx.cpu() to run on a CPU.
    score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
          max_num_examples=num_inference_images, logger=logger)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,7 +98,12 @@ enum class ExecType { | |
* In current implementation, copy operator is specially handled by executor. | ||
* This flag is used for special case treatment and future extension of different copy ops. | ||
*/ | ||
kCrossDeviceCopy | ||
kCrossDeviceCopy, | ||
/*! | ||
* A subgraph execution should happen in the main thread, instead of | ||
* in the execution engine. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. operators inside a subgraph are still executed in the execution engine, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
*/ | ||
kSubgraphExec, | ||
}; | ||
|
||
/*! \brief the dispatch mode of the operator */ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
#include "./c_api_common.h" | ||
#include "../operator/operator_common.h" | ||
#include "../executor/exec_pass.h" | ||
#include "../operator/subgraph/default_subgraph_op.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
@@ -625,3 +626,27 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, | |
*ret_qsym_handle = s; | ||
API_END_HANDLE_ERROR(delete s); | ||
} | ||
|
||
/*!
 * \brief Run the "PartitionGraph" pass on a copy of the given symbol.
 * \param sym_handle handle of the symbol to partition; the symbol is copied first
 * \param num_ops number of entries in op_names
 * \param op_names names of the operators used to build the DefaultSubgraphProperty
 * \param ret_sym_handle receives the newly allocated symbol whose outputs are
 *        those of the graph returned by the pass
 * \return status code produced by the API_BEGIN / API_END_HANDLE_ERROR macros
 */
int MXPartitionGraph(SymbolHandle sym_handle,
                     const mx_uint num_ops,
                     const char** op_names,
                     SymbolHandle* ret_sym_handle) {
  // Created ahead of API_BEGIN() so that the cleanup expression passed to
  // API_END_HANDLE_ERROR can refer to it.
  nnvm::Symbol* partitioned = new nnvm::Symbol();
  API_BEGIN();
  std::unordered_set<std::string> selected_ops(op_names, op_names + num_ops);
  *partitioned = static_cast<nnvm::Symbol*>(sym_handle)->Copy();
  nnvm::Graph graph = Symbol2Graph(*partitioned);
  // The subgraph property is attached only when at least one operator name was given.
  if (!selected_ops.empty()) {
    // Review note: according to the author, the SubgraphProperty implementation
    // determines which subgraph op node is used and, consequently, which
    // Forward/Backward implementation is adopted.
    mxnet::op::SubgraphPropertyPtr subgraph_prop
        = std::make_shared<mxnet::op::DefaultSubgraphProperty>(selected_ops);
    graph.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
  }
  graph = ApplyPass(std::move(graph), "PartitionGraph");
  partitioned->outputs = graph.outputs;
  *ret_sym_handle = partitioned;
  API_END_HANDLE_ERROR(delete partitioned);
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,8 +33,12 @@ | |
namespace mxnet { | ||
namespace engine { | ||
|
||
#if 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for debug? |
||
/*! \brief base class of engine variables, used for type checking */ | ||
struct Var { | ||
virtual uint32_t version() { | ||
return version_; | ||
} | ||
#if ENGINE_DEBUG | ||
virtual ~Var() = default; | ||
#endif // ENGINE_DEBUG | ||
|
@@ -45,7 +49,13 @@ struct Var { | |
*/ | ||
template <typename T> | ||
inline T* Cast(); | ||
/*! | ||
* \brief version number of the var. Every time the object it is associated with | ||
* is modified, the version number is incremented by 1. | ||
*/ | ||
uint32_t version_{0}; | ||
}; // struct Var | ||
#endif | ||
|
||
/*! \brief base class of engine operators, used for type checking */ | ||
struct Opr { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,10 +28,24 @@ | |
#include "./engine_impl.h" | ||
#include "../profiler/profiler.h" | ||
#include "./openmp.h" | ||
#include "../common/object_pool.h" | ||
|
||
namespace mxnet { | ||
namespace engine { | ||
|
||
/*! | ||
* \brief var used in Naive Engine for tracking the version | ||
* of the objects it is associated with. | ||
*/ | ||
class NaiveVar final | ||
: public Var, public common::ObjectPoolAllocatable<NaiveVar> { | ||
public: | ||
inline static NaiveVar* CastFromBase(Var* ptr) { | ||
return ptr->Cast<NaiveVar>(); | ||
} | ||
}; // class NaiveVar | ||
|
||
|
||
// implement naive engine | ||
class NaiveEngine final : public Engine { | ||
public: | ||
|
@@ -71,8 +85,11 @@ class NaiveEngine final : public Engine { | |
|
||
// Create a new engine variable. The handle is a NaiveVar obtained from
// NaiveVar::New(); the former counter-based handle has been dropped.
VarHandle NewVariable() override {
  return NaiveVar::New();
}
|
||
OprHandle NewOperator(AsyncFn fn, | ||
|
@@ -165,14 +182,26 @@ class NaiveEngine final : public Engine { | |
} | ||
CHECK(this->req_completed_) | ||
<< "NaiveEngine only support synchronize Push so far"; | ||
// increment var version | ||
for (auto var : mutable_vars) { | ||
++var->version_; | ||
} | ||
if (profiling) { | ||
opr->opr_profile->stop(); | ||
} | ||
} | ||
|
||
// Schedule deletion of an engine variable: the pushed function first runs
// delete_fn, then releases the NaiveVar with NaiveVar::Delete and finally
// signals completion. The variable is listed as the only mutable var of the
// pushed operation.
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
  NaiveVar* naive_var = NaiveVar::CastFromBase(var);
  this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
      delete_fn(ctx);
      NaiveVar::Delete(naive_var);
      on_complete();
    }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
}
|
||
void WaitForVar(VarHandle var) override { | ||
|
@@ -192,8 +221,6 @@ class NaiveEngine final : public Engine { | |
} | ||
// whether action is completed | ||
bool req_completed_; | ||
// counter | ||
std::atomic<size_t> counter_{0}; | ||
/*! \brief whether it is during shutdown phase*/ | ||
std::atomic<bool> shutdown_phase_{false}; | ||
// CPU stream | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you write a simple README for this file? How to run this test? How can users play with this test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the following command to try it. Please remember to change `ctx` passed to the function `score` into `mx.cpu()` if you are running on a CPU. It's just a preliminary end-to-end run-through for proof of concept. More changes are coming.