diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 7047364966af..2f3ed91cb5b7 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1024,6 +1024,14 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): ret.collect_params().load(param_file, ctx=ctx) return ret + def __repr__(self): + s = '{name}(\n{modstr}\n)' + modstr = '\n'.join(['{block} : {numinputs} -> {numoutputs}'.format(block=self._cached_graph[1], + numinputs=len(self._cached_graph[0]), + numoutputs=len(self._cached_graph[1]. + list_outputs()))]) + return s.format(name=self.__class__.__name__, + modstr=modstr) def __init__(self, outputs, inputs, params=None): super(SymbolBlock, self).__init__(prefix=None, params=None) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 34380dc00314..6af7a5f948e2 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -882,8 +882,13 @@ def test_import(): net2 = gluon.SymbolBlock.imports( 'net1-symbol.json', ['data'], 'net1-0001.params', ctx) out2 = net2(data) + lines = str(net2).splitlines() assert_almost_equal(out1.asnumpy(), out2.asnumpy()) + assert lines[0] == 'SymbolBlock(' + assert lines[1] + assert lines[2] == ')' + @with_seed() def test_hybrid_stale_cache():