From ccd0abdc20933448e748270cb9059433e94c4371 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 21 Sep 2015 12:14:24 -0700
Subject: [PATCH] Allow partial positional arguments of input symbol

---
 src/symbol/symbol.cc           | 29 ++++++++++++++++++-----------
 tests/python/train/test_mlp.py | 12 ++++++------
 2 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index 2b923cebc0d9..fb2377dbb6b2 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -246,6 +246,16 @@ Symbol Symbol::operator[] (size_t index) const {
   }
 }
 
+// create a default variable name
+inline std::string DefaultVarName(const std::string &op_name,
+                                  const std::string &arg_name) {
+  if (op_name.length() == 0) {
+    return arg_name;
+  } else {
+    return op_name + '_' + arg_name;
+  }
+}
+
 void Symbol::Compose(const std::vector<Symbol>& args,
                      const std::string& name) {
   CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently";
@@ -261,13 +271,17 @@ void Symbol::Compose(const std::vector<Symbol>& args,
   if (this->is_atomic()) {
     // atomic symbol do not have place holder for all the arguments
     std::vector<std::string> req_args = heads_[0].source->op->ListArguments();
-    CHECK_EQ(args.size(), req_args.size())
+    CHECK_LE(args.size(), req_args.size())
         << "Incorrect number of arguments, requires " << req_args.size()
         << ", provided " << args.size();
-    heads_[0].source->inputs.resize(args.size());
+    heads_[0].source->inputs.resize(req_args.size());
     for (size_t i = 0; i < args.size(); ++i) {
       heads_[0].source->inputs[i] = args[i].heads_[0];
     }
+    for (size_t i = args.size(); i < req_args.size(); ++i) {
+      heads_[0].source->inputs[i] = DataEntry(
+          std::make_shared<Node>(nullptr, DefaultVarName(name, req_args[i])), 0);
+    }
   } else {
     // find all the place holders
     size_t arg_counter = 0;
@@ -325,15 +339,8 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
       heads_[0].source->inputs[i] = iter->second.heads_[0];
       ++nmatched;
     } else {
-      // create a variable node
-      // TODO(bing): think of naming convention
-      if (name.length() == 0) {
-        heads_[0].source->inputs[i] = DataEntry(
-            std::make_shared<Node>(nullptr, req_args[i]), 0);
-      } else {
-        heads_[0].source->inputs[i] = DataEntry(
-            std::make_shared<Node>(nullptr, name + '_' + req_args[i]), 0);
-      }
+      heads_[0].source->inputs[i] = DataEntry(
+          std::make_shared<Node>(nullptr, DefaultVarName(name, req_args[i])), 0);
     }
   }
   // if things goes wrong recover the old state
diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py
index 350adfde274e..40304187a5fa 100644
--- a/tests/python/train/test_mlp.py
+++ b/tests/python/train/test_mlp.py
@@ -9,12 +9,12 @@
 # symbol net
 batch_size = 100
 data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-softmax = mx.symbol.Softmax(data = fc3, name = 'sm')
+fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
+act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
+act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
+fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
+softmax = mx.symbol.Softmax(fc3, name = 'sm')
 
 num_round = 4
 prefix = './mlp'
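
A minimal sketch of the calling style this patch enables, assuming an MXNet build at this revision is importable as mx (the symbol names below are illustrative, not part of the patch). Positional inputs no longer have to cover every argument of an atomic symbol; arguments left unspecified are auto-created as variables named by DefaultVarName, i.e. "<op_name>_<arg_name>".

    import mxnet as mx

    data = mx.symbol.Variable('data')
    # 'data' is passed positionally; weight and bias are omitted, so Compose
    # should auto-create variables for them using the op name as a prefix.
    fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
    # expected: ['data', 'fc1_weight', 'fc1_bias']
    print(fc1.list_arguments())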