diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index fc08b4c6bd32..09dd51554fe6 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -39,6 +39,7 @@
 from ..util import is_np_array, np_shape, np_array
 
+
 class _BlockScope(object):
     """Scope for collecting child `Block` s."""
     _current = threading.local()
@@ -93,16 +94,35 @@ def __exit__(self, ptype, value, trace):
 
 
 def _flatten(args, inout_str):
+    """Parse the arguments into a flattened list + an additional format array.
+    The format array stores the structure of the original arguments to help
+    reconstruct the inputs.
+
+    Parameters
+    ----------
+    args : NDArray, Symbol, or (nested) list of Symbol or NDArray
+        We allow None inside the args.
+    inout_str : str
+        Specifies whether the arguments are inputs or outputs; it is only
+        used in the error message.
+
+    Returns
+    -------
+    flat : list of Symbol or NDArray
+        The flattened version of the input args.
+    fmts : (nested) list of ints
+        Stores the format information of the original structured args.
+    """
     if isinstance(args, NDArray):
         return [args], int(0)
     if isinstance(args, Symbol):
         length = len(args.list_outputs())
         length = length if length > 1 else 0
         return [args], int(length)
+    if args is None:
+        return [None], int(-1)
 
     assert isinstance(args, (list, tuple)), \
-        "HybridBlock %s must be (nested) list of Symbol or NDArray, " \
-        "but got %s of type %s"%(inout_str, str(args), str(type(args)))
+        "HybridBlock {} must be (nested) list of Symbol or NDArray, " \
+        "but got {} of type {}".format(inout_str, str(args), str(type(args)))
     flat = []
     fmts = []
     for i in args:
@@ -113,19 +133,46 @@ def _flatten(args, inout_str):
 
 
 def _regroup(args, fmt):
-    if isinstance(fmt, int):
-        if fmt == 0:
-            return args[0], args[1:]
-        return args[:fmt], args[fmt:]
+    """Reconstruct the structured arguments based on the flattened version.
 
-    assert isinstance(args, (list, tuple)), \
-        "HybridBlock output must be (nested) list of Symbol or NDArray, " \
-        "but got %s of type %s"%(str(args), str(type(args)))
-    ret = []
-    for i in fmt:
-        res, args = _regroup(args, i)
-        ret.append(res)
-    return ret, args
+    Parameters
+    ----------
+    args : NDArray, Symbol, or (nested) list of Symbol or NDArray
+        We allow None inside the args.
+    fmt : (nested) list of ints
+        Stores the format information of the original structured args.
+
+    Returns
+    -------
+    ret : NDArray, Symbol, or (nested) list of Symbol or NDArray
+
+    """
+    def _merger(args, fmt):
+        """Recursive call to merge the arguments"""
+        if isinstance(fmt, int):
+            if fmt < -1:
+                raise ValueError("Unsupported encoded format {}.".format(fmt))
+            if fmt == 0:
+                return args[0], args[1:]
+            if fmt == -1:
+                if args[0] is not None:
+                    raise ValueError('Passing a non-None value is not supported at a'
+                                     ' position where the hybridized HybridBlock'
+                                     ' originally received None.'
+                                     ' Received arg = {}, fmt = {}.'.format(args[0], fmt))
+                return None, args[1:]
+            else:
+                return args[:fmt], args[fmt:]
+
+        assert isinstance(args, (list, tuple)), \
+            "HybridBlock output must be (nested) list of Symbol or NDArray, " \
+            "but got {} of type {}".format(args, type(args))
+        ret = []
+        for i in fmt:
+            res, args = _merger(args, i)
+            ret.append(res)
+        return ret, args
+    return _merger(args, fmt)[0]
 
 
 class Block(object):
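To make the `-1` encoding concrete, here is a minimal, self-contained sketch of the flatten/regroup round trip (plain Python, no mxnet imports; it ignores the multi-output `Symbol` case, and the function names are illustrative):

```python
# Leaves are encoded as ints (0 = a single array, -1 = None); (nested)
# lists/tuples are encoded as (nested) lists of these ints.
def flatten(args):
    if args is None:
        return [None], -1
    if not isinstance(args, (list, tuple)):
        return [args], 0
    flat, fmts = [], []
    for a in args:
        f, fmt = flatten(a)
        flat.extend(f)
        fmts.append(fmt)
    return flat, fmts

def regroup(args, fmt):
    if isinstance(fmt, int):
        if fmt == -1:
            return None, args[1:]
        if fmt == 0:
            return args[0], args[1:]
        return args[:fmt], args[fmt:]
    ret = []
    for f in fmt:
        res, args = regroup(args, f)
        ret.append(res)
    return ret, args

flat, fmt = flatten(['a', None, ['b', None]])
assert flat == ['a', None, 'b', None] and fmt == [0, -1, [0, -1]]
assert regroup(flat, fmt)[0] == ['a', None, ['b', None]]
```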
@@ -778,29 +825,40 @@ def __setattr__(self, name, value):
 
     def _get_graph(self, *args):
         if not self._cached_graph:
-            args, self._in_format = _flatten(args, "input")
-            if len(args) > 1:
-                inputs = [symbol.var('data%d' % i).as_np_ndarray()
-                          if isinstance(args[i], _mx_np.ndarray)
-                          else symbol.var('data%d' % i) for i in range(len(args))]
-            else:
-                inputs = [symbol.var('data').as_np_ndarray()
-                          if isinstance(args[0], _mx_np.ndarray)
-                          else symbol.var('data')]
-            grouped_inputs = _regroup(inputs, self._in_format)[0]
-
+            flatten_args, self._in_format = _flatten(args, "input")
+            flatten_inputs = []
+            symbol_inputs = []
+            cnt = 0
+            real_arg_num = sum([ele is not None for ele in flatten_args])
+            if real_arg_num == 0:
+                raise ValueError('All args are None; at least one argument must be an'
+                                 ' NDArray or Symbol. Received args={}'.format(args))
+            for arg in flatten_args:
+                if arg is not None:
+                    if real_arg_num > 1:
+                        arg_sym = symbol.var('data{}'.format(cnt))
+                    else:
+                        arg_sym = symbol.var('data')
+                    if isinstance(arg, _mx_np.ndarray):
+                        arg_sym = arg_sym.as_np_ndarray()
+                    cnt += 1
+                    flatten_inputs.append(arg_sym)
+                    symbol_inputs.append(arg_sym)
+                else:
+                    flatten_inputs.append(None)
+            grouped_inputs = _regroup(flatten_inputs, self._in_format)
             params = {i: j.var() for i, j in self._reg_params.items()}
             with self.name_scope():
                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
             out, self._out_format = _flatten(out, "output")
 
-            self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out))
+            self._cached_graph = symbol_inputs, symbol.Group(out, _check_same_symbol_type(out))
 
         return self._cached_graph
 
     def _build_cache(self, *args):
         data, out = self._get_graph(*args)
-        data_names = {data.name : i for i, data in enumerate(data)}
+        data_names = {data.name: i for i, data in enumerate(data)}
         params = self.collect_params()
         input_names = out.list_inputs()
@@ -808,7 +866,7 @@
         expected_names = set(input_names)
         for name in expected_names:
             assert name in param_names or name in data_names, \
-                "Unknown input to HybridBlock: %s"%name
+                "Unknown input to HybridBlock: %s" % name
 
         used_data_names = [i for i in data_names if i in expected_names]
         if len(used_data_names) != len(data_names):
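A small illustration of the input-naming behavior above (assumes an mxnet build with this patch; `AddOrPass` is a made-up block, and `_cached_graph` is a private attribute inspected only for demonstration): because `None` inputs create no placeholder variable, a block called with `(x, None)` ends up with a single symbolic input named `'data'` rather than `'data0'`/`'data1'`.

```python
import mxnet as mx
from mxnet import gluon

class AddOrPass(gluon.HybridBlock):
    def hybrid_forward(self, F, a, b):
        return a if b is None else a + b

net = AddOrPass()
net.hybridize()
net(mx.nd.ones((2,)), None)       # first call builds and caches the graph
inputs, _ = net._cached_graph     # private attribute, for illustration only
print([i.name for i in inputs])   # -> ['data']
```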
@@ -856,23 +914,40 @@ def _call_cached_op(self, *args):
                           " and may not work correctly")
 
         args, fmt = _flatten(args, "input")
-        assert fmt == self._in_format, "Invalid input format"
+        if fmt != self._in_format:
+            # Do not raise an error if fmt or self._in_format merely has extra
+            # trailing None (-1) entries, i.e. the call relies on default values.
+            if len(self._in_format) > len(fmt):
+                valid = all([self._in_format[i] == -1
+                             for i in range(len(fmt), len(self._in_format))])
+                valid = valid and (fmt == self._in_format[:len(fmt)])
+            elif len(self._in_format) < len(fmt):
+                valid = all([fmt[i] == -1
+                             for i in range(len(self._in_format), len(fmt))])
+                valid = valid and (fmt[:len(self._in_format)] == self._in_format)
+            else:
+                valid = False
+            if not valid:
+                raise ValueError("The argument structure of HybridBlock does not match"
+                                 " the cached version. Stored format = {}, input format = {}"
+                                 .format(self._in_format, fmt))
+        args_without_none = [ele for ele in args if ele is not None]
         try:
-            cargs = [args[i] if is_arg else i.data()
+            cargs = [args_without_none[i] if is_arg else i.data()
                      for is_arg, i in self._cached_op_args]
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
             cargs = []
             for is_arg, i in self._cached_op_args:
                 if is_arg:
-                    cargs.append(args[i])
+                    cargs.append(args_without_none[i])
                 else:
                     i._finish_deferred_init()
                     cargs.append(i.data())
         out = self._cached_op(*cargs)
         if isinstance(out, NDArray):
             out = [out]
-        return _regroup(out, self._out_format)[0]
+        return _regroup(out, self._out_format)
 
     def _clear_cached_op(self):
         self._cached_graph = ()
@@ -906,9 +981,10 @@ def _infer_attrs(self, infer_fn, attr, *args):
         """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
         args, _ = _flatten(args, "input")
+        args_without_none = [ele for ele in args if ele is not None]
         with warnings.catch_warnings(record=True) as w:
             arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
-                **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+                **{i.name: getattr(j, attr) for i, j in zip(inputs, args_without_none)})
             if arg_attrs is None:
                 raise ValueError(w[0].message)
         sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
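The compatibility rule in `_call_cached_op` above boils down to: two formats match when one is a prefix of the other and the leftover tail consists solely of `-1` (None) markers. A standalone sketch (illustrative function name; top-level formats assumed to be lists):

```python
def formats_compatible(stored, new):
    if stored == new:
        return True
    shorter, longer = (stored, new) if len(stored) < len(new) else (new, stored)
    return longer[:len(shorter)] == shorter and all(f == -1 for f in longer[len(shorter):])

assert formats_compatible([0, -1], [0])      # trailing None falls back to the default
assert not formats_compatible([0, 0], [0])   # a real argument is missing
```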
@@ -978,24 +1054,48 @@ def register_op_hook(self, callback, monitor_all=False):
 
     def forward(self, x, *args):
         """Defines the forward computation. Arguments can be either
         :py:class:`NDArray` or :py:class:`Symbol`."""
-        if isinstance(x, NDArray):
-            with x.context as ctx:
+        flatten_args = _flatten([x] + list(args), 'inputs')[0]
+        is_ndarray = None
+        ctx = None
+        exist_sym_nd = False
+        for ele in flatten_args:
+            if isinstance(ele, NDArray):
+                if is_ndarray is False:
+                    raise ValueError('In HybridBlock, mixing NDArray and Symbol types'
+                                     ' in the input is not supported.\n'
+                                     'Received types are: {}.'
+                                     .format([type(ele) for ele in flatten_args]))
+                is_ndarray = True
+                exist_sym_nd = True
+                ctx = ele.context
+            elif isinstance(ele, Symbol):
+                if is_ndarray:
+                    raise ValueError('In HybridBlock, mixing NDArray and Symbol types'
+                                     ' in the input is not supported.\n'
+                                     'Received types are: {}.'
+                                     .format([type(ele) for ele in flatten_args]))
+                is_ndarray = False
+                exist_sym_nd = True
+            else:
+                assert ele is None, 'Only None, NDArray and Symbol are supported as inputs'
+        if not exist_sym_nd:
+            raise ValueError('There must be at least one NDArray or Symbol in the input.')
+
+        if is_ndarray:
+            with ctx:
                 if self._active:
                     return self._call_cached_op(x, *args)
 
                 try:
-                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
+                    params = {k: v.data(ctx) for k, v in self._reg_params.items()}
                 except DeferredInitializationError:
                     self._deferred_infer_shape(x, *args)
-                    for _, i in self.params.items():
-                        i._finish_deferred_init()
-                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
+                    for _, v in self.params.items():
+                        v._finish_deferred_init()
+                    params = {k: v.data(ctx) for k, v in self._reg_params.items()}
 
                 return self.hybrid_forward(ndarray, x, *args, **params)
 
-        assert isinstance(x, Symbol), \
-            "HybridBlock requires the first argument to forward be either " \
-            "Symbol or NDArray, but got %s"%type(x)
         params = {i: j.var() for i, j in self._reg_params.items()}
 
         with self.name_scope():
             return self.hybrid_forward(symbol, x, *args, **params)
@@ -1173,7 +1273,7 @@ def forward(self, x, *args):
             assert in_fmt == self._in_format, "Invalid input format"
 
             ret = copy.copy(self._cached_graph[1])
             ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
-            return _regroup(list(ret), self._out_format)[0]
+            return _regroup(list(ret), self._out_format)
 
     def _clear_cached_op(self):
         tmp = self._cached_graph
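End-to-end, the new behavior lets a hybridized block accept `None` and still dispatch on input type. A usage sketch (assumes an mxnet build with this patch; `Residual` is a made-up block):

```python
import mxnet as mx
from mxnet import gluon

class Residual(gluon.HybridBlock):
    def hybrid_forward(self, F, x, shortcut):
        return x if shortcut is None else x + shortcut

net = Residual()
net.hybridize()
out_nd = net(mx.nd.ones((3,)), None)    # NDArray path, cached after the first call
out_sym = net(mx.sym.var('x'), None)    # Symbol path returns a Symbol
print(out_nd, type(out_sym))
```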
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index ae2d62451a09..380ce762a9f7 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -441,6 +441,94 @@ def test_sparse_hybrid_block():
     # an exception is expected when forwarding a HybridBlock w/ sparse param
     y = net(x)
 
+@with_seed()
+def test_hybrid_block_none_args():
+    class Foo(gluon.HybridBlock):
+        def hybrid_forward(self, F, a, b):
+            if a is None and b is not None:
+                return b
+            elif b is None and a is not None:
+                return a
+            elif a is not None and b is not None:
+                return a + b
+            else:
+                raise NotImplementedError
+
+    class FooDefault(gluon.HybridBlock):
+        def hybrid_forward(self, F, a, b=None):
+            if a is None and b is not None:
+                return b
+            elif b is None and a is not None:
+                return a
+            elif a is not None and b is not None:
+                return a + b
+            else:
+                raise NotImplementedError
+
+    class FooNested(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(FooNested, self).__init__(prefix=prefix, params=params)
+            self.f1 = Foo(prefix='foo1')
+            self.f2 = Foo(prefix='foo2')
+            self.f3 = Foo(prefix='foo3')
+
+        def hybrid_forward(self, F, a, b):
+            data = self.f1(a, b)
+            data = self.f2(a, data)
+            data = self.f3(data, b)
+            return data
+
+    for arg_inputs in [(None, mx.nd.ones((10,))),
+                       (mx.nd.ones((10,)), mx.nd.ones((10,))),
+                       (mx.nd.ones((10,)), None)]:
+        foo1 = FooNested(prefix='foo_nested_hybridized')
+        foo1.hybridize()
+        foo2 = FooNested(prefix='foo_nested_nohybrid')
+        for _ in range(2):  # loop twice to also exercise the cached-op forward path
+            out1 = foo1(*arg_inputs)
+            out2 = foo2(*arg_inputs)
+            if isinstance(out1, tuple):
+                for lhs, rhs in zip(out1, out2):
+                    assert_almost_equal(lhs.asnumpy(), rhs.asnumpy())
+            else:
+                assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+
+    for do_hybridize in [True, False]:
+        foo = FooNested()
+        if do_hybridize:
+            foo.hybridize()
+        assert_raises(ValueError, foo, None, None)
+
+    # Make sure the ValueError is correctly raised
+    foo = FooNested()
+    foo.hybridize()
+    foo(None, mx.nd.ones((10,)))  # first call initializes the cached op
+    assert_raises(ValueError, lambda: foo(mx.nd.ones((10,)), mx.nd.ones((10,))))
+    foo = FooNested()
+    assert_raises(ValueError, lambda: foo(mx.nd.ones((10,)), mx.sym.var('a')))
+    foo = FooNested()
+    assert_raises(ValueError, lambda: foo(mx.sym.var('a'), mx.nd.ones((10,))))
+
+    # Test the case of the default values
+    foo1 = FooDefault()
+    foo1.hybridize()
+    foo2 = FooDefault()
+    out1 = foo1(mx.nd.ones((10,)))
+    out2 = foo2(mx.nd.ones((10,)))
+    out3 = foo1(mx.nd.ones((10,)), None)
+    out4 = foo2(mx.nd.ones((10,)), None)
+    assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+    assert_almost_equal(out1.asnumpy(), out3.asnumpy())
+    assert_almost_equal(out1.asnumpy(), out4.asnumpy())
+    foo1 = FooDefault()
+    foo1.hybridize()
+    out1 = foo1(mx.nd.ones((10,)), None)
+    out2 = foo1(mx.nd.ones((10,)))
+    assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+    assert_raises(ValueError, lambda: foo1(mx.nd.ones((10,)), mx.nd.ones((10,))))
+
+
 @with_seed()
 def check_layer_forward(layer, dshape):
     print("checking layer {}\nshape: {}.".format(layer, dshape))