Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add repr for SymbolBlock (#14423)
Browse files Browse the repository at this point in the history
* Add repr for SymbolBlock

* Add a test

* Correct self.cached_graph

* Address review comments
  • Loading branch information
vandanavk authored and szha committed Mar 15, 2019
1 parent 43173f5 commit 226212b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,14 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
ret.collect_params().load(param_file, ctx=ctx)
return ret

def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join(['{block} : {numinputs} -> {numoutputs}'.format(block=self._cached_graph[1],
numinputs=len(self._cached_graph[0]),
numoutputs=len(self._cached_graph[1].
list_outputs()))])
return s.format(name=self.__class__.__name__,
modstr=modstr)

def __init__(self, outputs, inputs, params=None):
super(SymbolBlock, self).__init__(prefix=None, params=None)
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,13 @@ def test_import():
net2 = gluon.SymbolBlock.imports(
'net1-symbol.json', ['data'], 'net1-0001.params', ctx)
out2 = net2(data)
lines = str(net2).splitlines()

assert_almost_equal(out1.asnumpy(), out2.asnumpy())
assert lines[0] == 'SymbolBlock('
assert lines[1]
assert lines[2] == ')'


@with_seed()
def test_hybrid_stale_cache():
Expand Down

0 comments on commit 226212b

Please sign in to comment.