Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add repr for SymbolBlock #14423

Merged
merged 4 commits into from
Mar 15, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
ret.collect_params().load(param_file, ctx=ctx)
return ret

def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join(['{block}'.format(block=self._cached_graph[-1])])
return s.format(name=self.__class__.__name__,
modstr=modstr)

def __init__(self, outputs, inputs, params=None):
super(SymbolBlock, self).__init__(prefix=None, params=None)
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,13 @@ def test_import():
net2 = gluon.SymbolBlock.imports(
'net1-symbol.json', ['data'], 'net1-0001.params', ctx)
out2 = net2(data)
lines = str(net2).splitlines()

assert_almost_equal(out1.asnumpy(), out2.asnumpy())
assert lines[0] == 'SymbolBlock('
assert lines[1]
assert lines[2] == ')'


@with_seed()
def test_hybrid_stale_cache():
Expand Down