Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Aug 9, 2019
1 parent 2a3b553 commit 11834a2
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,15 @@ def test_children_same_name():
for c in b.get_children():
pass

def test_gen_atomic_symbol_multiple_outputs():
data=mx.sym.Variable('data')
p = mx.sym.Variable('param')
h0 = mx.sym.Variable('h0')
h1 = mx.sym.Variable('h1')
s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
bidirectional=True, state_outputs=True, mode='lstm')
atomic_sym = s._gen_atomic_symbol()

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 11834a2

Please sign in to comment.