Skip to content

Commit

Permalink
refer to PR: Allow operators with multiple outputs in get_atomic_symbol
Browse files Browse the repository at this point in the history
  • Loading branch information
joapolarbear committed Oct 28, 2020
1 parent 87fe065 commit f9f2b1e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
14 changes: 7 additions & 7 deletions benchmark/python/gluon/benchmark_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
'By default, use CPU only.')
parser.add_argument('--type', type=str, default='inference', choices=['all', 'training', 'inference'])

opt = parser.parse_args()
args = parser.parse_args()

num_batches = opt.num_batches
num_batches = args.num_batches
dry_run = 10 # use 10 iterations to warm up
batch_inf = [1, 32, 64, 128, 256]
batch_train = [1, 32, 64, 128, 256]
Expand Down Expand Up @@ -116,10 +116,10 @@ def train(network, batch_size, ctx):
return bwd

if __name__ == '__main__':
runtype = opt.type
bs = opt.batch_size
runtype = args.type
bs = args.batch_size

if opt.model == 'all':
if args.model == 'all':
networks = ['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'inceptionv3', 'mobilenet0.25', 'mobilenet0.5', 'mobilenet0.75',
'mobilenet1.0', 'mobilenetv2_0.25', 'mobilenetv2_0.5', 'mobilenetv2_0.75',
Expand All @@ -130,9 +130,9 @@ def train(network, batch_size, ctx):
logging.info('It may take some time to run all models, '
'set --network to run a specific one')
else:
networks = [opt.model]
networks = [args.model]

devs = [mx.gpu(int(i)) for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
devs = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus.strip() else [mx.cpu()]
num_gpus = len(devs)

for network in networks:
Expand Down
12 changes: 8 additions & 4 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -854,10 +854,14 @@ int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_h
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
CHECK_EQ(source->outputs.size(), 1U)
<< "Generating atomic symbol from other symbol only works for nongrouped symbol.";
const auto& node = source->outputs[0];
const auto *op = node.node->op();
CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs.";
const auto &node = source->outputs[0].node;
for (const auto &other_node : source->outputs) {
if (node.get() != other_node.node.get()) {
LOG(FATAL) << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
}
}
const auto *op = node->op();
const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
*s = nnvm::Symbol::CreateFunctor(op, attrs);
*ret_sym_handle = s;
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,15 @@ def test_children_same_name():
for c in b.get_children():
pass

def test_gen_atomic_symbol_multiple_outputs():
    """Regression test: _gen_atomic_symbol must accept a symbol with multiple
    outputs, as long as all outputs come from the same node.

    A bidirectional LSTM with state_outputs=True produces several outputs
    from a single RNN node; previously this raised because only
    single-output symbols were allowed.
    """
    data = mx.sym.Variable('data')
    p = mx.sym.Variable('param')
    h0 = mx.sym.Variable('h0')
    h1 = mx.sym.Variable('h1')
    s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
                   bidirectional=True, state_outputs=True, mode='lstm')
    # Must not raise, and must yield a usable atomic symbol.
    atomic_sym = s._gen_atomic_symbol()
    assert atomic_sym is not None

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit f9f2b1e

Please sign in to comment.