Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Allow operators with multiple outputs in get_atomic_symbol (#15740)
Browse files Browse the repository at this point in the history
* Allow operators with multiple outputs in get_atomic_symbol

* Added unittest
  • Loading branch information
ptrendx committed Aug 13, 2019
1 parent 11ce2a2 commit 24a5cf0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1158,10 +1158,15 @@ int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_h
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
CHECK_EQ(source->outputs.size(), 1U)
<< "Generating atomic symbol from other symbol only works for nongrouped symbol.";
const auto &node = source->outputs[0];
const auto *op = node.node->op();
CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs.";
const auto &node = source->outputs[0].node;
for (const auto &other_node : source->outputs) {
if (node.get() != other_node.node.get()) {
LOG(FATAL)
<< "Generating atomic symbol from other symbol only works for nongrouped symbol.";
}
}
const auto *op = node->op();
const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
*s = nnvm::Symbol::CreateFunctor(op, attrs);
*ret_sym_handle = s;
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,15 @@ def test_children_same_name():
for c in b.get_children():
pass

def test_gen_atomic_symbol_multiple_outputs():
data=mx.sym.Variable('data')
p = mx.sym.Variable('param')
h0 = mx.sym.Variable('h0')
h1 = mx.sym.Variable('h1')
s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
bidirectional=True, state_outputs=True, mode='lstm')
atomic_sym = s._gen_atomic_symbol()

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 24a5cf0

Please sign in to comment.