Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Gluon] Support None argument in HybridBlock #16280

Merged
merged 4 commits into from
Sep 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 143 additions & 43 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .. util import is_np_array, np_shape, np_array



class _BlockScope(object):
"""Scope for collecting child `Block` s."""
_current = threading.local()
Expand Down Expand Up @@ -93,16 +94,35 @@ def __exit__(self, ptype, value, trace):


def _flatten(args, inout_str):
"""Parse the arguments into a flattened list + an additional format array.
The format array stores the structure of the original arguments to help reconstruct the inputs.

Parameters
----------
args : NDArray, Symbol, or (nested) list of Symbol or NDArray
We allow None inside the args.
inout_str : str
The name of the HybridBlock

Returns
-------
flat : list of Symbol or NDArray
The flatten version of the input args.
fmts : (nested) list of ints
Stores the format information of the original structured args.
"""
if isinstance(args, NDArray):
return [args], int(0)
if isinstance(args, Symbol):
length = len(args.list_outputs())
length = length if length > 1 else 0
return [args], int(length)
if args is None:
return [None], int(-1)

assert isinstance(args, (list, tuple)), \
"HybridBlock %s must be (nested) list of Symbol or NDArray, " \
"but got %s of type %s"%(inout_str, str(args), str(type(args)))
"HybridBlock {} must be (nested) list of Symbol or NDArray, " \
"but got {} of type {}".format(inout_str, str(args), str(type(args)))
flat = []
fmts = []
for i in args:
Expand All @@ -113,19 +133,46 @@ def _flatten(args, inout_str):


def _regroup(args, fmt):
if isinstance(fmt, int):
if fmt == 0:
return args[0], args[1:]
return args[:fmt], args[fmt:]
"""Reconstruct the structured arguments based on the flattened version.

assert isinstance(args, (list, tuple)), \
"HybridBlock output must be (nested) list of Symbol or NDArray, " \
"but got %s of type %s"%(str(args), str(type(args)))
ret = []
for i in fmt:
res, args = _regroup(args, i)
ret.append(res)
return ret, args
Parameters
----------
args : NDArray, Symbol, or (nested) list of Symbol or NDArray
We allow None inside the args.
fmt : (nested) list of ints
Stores the format information of the original structured args.

Returns
-------
ret : NDArray, Symbol, or (nested) list of Symbol or NDArray

"""
def _merger(args, fmt):
"""Recursive call to merge the arguments"""
if isinstance(fmt, int):
if fmt < -1:
raise ValueError("Unsupported encoded format {}.".format(fmt))
if fmt == 0:
return args[0], args[1:]
if fmt == -1:
if args[0] is not None:
raise ValueError('We do not support passing types that are not None'
' when the initial HybridBlock has received NoneType and'
' has been hybridized.'
' Received arg = {}, fmt = {}.'.format(args[0], fmt))
return None, args[1:]
else:
return args[:fmt], args[fmt:]

assert isinstance(args, (list, tuple)), \
"HybridBlock output must be (nested) list of Symbol or NDArray, " \
"but got {} of type {}".format(args, type(args))
ret = []
for i in fmt:
res, args = _merger(args, i)
ret.append(res)
return ret, args
return _merger(args, fmt)[0]


class Block(object):
Expand Down Expand Up @@ -778,37 +825,48 @@ def __setattr__(self, name, value):

def _get_graph(self, *args):
if not self._cached_graph:
args, self._in_format = _flatten(args, "input")
if len(args) > 1:
inputs = [symbol.var('data%d' % i).as_np_ndarray()
if isinstance(args[i], _mx_np.ndarray)
else symbol.var('data%d' % i) for i in range(len(args))]
else:
inputs = [symbol.var('data').as_np_ndarray()
if isinstance(args[0], _mx_np.ndarray)
else symbol.var('data')]
grouped_inputs = _regroup(inputs, self._in_format)[0]

flatten_args, self._in_format = _flatten(args, "input")
flatten_inputs = []
symbol_inputs = []
cnt = 0
real_arg_num = sum([ele is not None for ele in flatten_args])
if real_arg_num == 0:
raise ValueError('All args are None and we do not support such a case.'
' Received args={}'.format(args))
for arg in flatten_args:
if arg is not None:
if real_arg_num > 1:
arg_sym = symbol.var('data{}'.format(cnt))
else:
arg_sym = symbol.var('data')
if isinstance(arg, _mx_np.ndarray):
arg_sym = arg_sym.as_np_ndarray()
cnt += 1
flatten_inputs.append(arg_sym)
symbol_inputs.append(arg_sym)
else:
flatten_inputs.append(None)
grouped_inputs = _regroup(flatten_inputs, self._in_format)
params = {i: j.var() for i, j in self._reg_params.items()}
with self.name_scope():
out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter
out, self._out_format = _flatten(out, "output")

self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out))
self._cached_graph = symbol_inputs, symbol.Group(out, _check_same_symbol_type(out))

return self._cached_graph

def _build_cache(self, *args):
data, out = self._get_graph(*args)
data_names = {data.name : i for i, data in enumerate(data)}
data_names = {data.name: i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(input_names)
for name in expected_names:
assert name in param_names or name in data_names, \
"Unknown input to HybridBlock: %s"%name
"Unknown input to HybridBlock: %s" %name

used_data_names = [i for i in data_names if i in expected_names]
if len(used_data_names) != len(data_names):
Expand Down Expand Up @@ -856,23 +914,40 @@ def _call_cached_op(self, *args):
" and may not work correctly")

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
if fmt != self._in_format:
# Do not raise in the case that the fmt or stored_fmt ends with None and
# We are relying on the default values.
if len(self._in_format) > len(fmt):
valid = all([self._in_format[i] == -1
for i in range(len(fmt), len(self._in_format))])
valid = valid and (fmt == self._in_format[:len(fmt)])
elif len(self._in_format) < len(fmt):
valid = all([fmt[i] == -1
for i in range(len(self._in_format), len(fmt))])
valid = valid and (fmt[:len(self._in_format)] == self._in_format)
else:
valid = False
if not valid:
raise ValueError("The argument structure of HybridBlock does not match"
" the cached version. Stored format = {}, input format = {}"
.format(fmt, self._in_format))
args_without_none = [ele for ele in args if ele is not None]
try:
cargs = [args[i] if is_arg else i.data()
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args[i])
cargs.append(args_without_none[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
return _regroup(out, self._out_format)

def _clear_cached_op(self):
self._cached_graph = ()
Expand Down Expand Up @@ -906,9 +981,10 @@ def _infer_attrs(self, infer_fn, attr, *args):
"""Generic infer attributes."""
inputs, out = self._get_graph(*args)
args, _ = _flatten(args, "input")
args_without_none = [ele for ele in args if ele is not None]
with warnings.catch_warnings(record=True) as w:
arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
**{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
**{i.name: getattr(j, attr) for i, j in zip(inputs, args_without_none)})
if arg_attrs is None:
raise ValueError(w[0].message)
sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
Expand Down Expand Up @@ -978,24 +1054,48 @@ def register_op_hook(self, callback, monitor_all=False):
def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
flatten_args = _flatten([x] + list(args), 'inputs')[0]
is_ndarray = None
ctx = None
exist_sym_nd = False
for ele in flatten_args:
if isinstance(ele, NDArray):
if is_ndarray is False:
raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
' types for the input.\n'
'Received types are: {}.'
.format([type(ele) for ele in flatten_args]))
is_ndarray = True
exist_sym_nd = True
ctx = ele.context
elif isinstance(ele, Symbol):
if is_ndarray:
raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
' types for the input.\n'
'Received types are: {}.'
.format([type(ele) for ele in flatten_args]))
is_ndarray = False
exist_sym_nd = True
else:
assert ele is None, 'Only support None, NDArray and Symbol as the input'
if not exist_sym_nd:
raise ValueError('There must at least one NDArray or Symbol in the input, received')

if is_ndarray:
with ctx:
if self._active:
return self._call_cached_op(x, *args)

try:
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
params = {k: v.data(ctx) for k, v in self._reg_params.items()}
except DeferredInitializationError:
self._deferred_infer_shape(x, *args)
for _, i in self.params.items():
i._finish_deferred_init()
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
for _, v in self.params.items():
v._finish_deferred_init()
params = {k: v.data(ctx) for k, v in self._reg_params.items()}

return self.hybrid_forward(ndarray, x, *args, **params)

assert isinstance(x, Symbol), \
"HybridBlock requires the first argument to forward be either " \
"Symbol or NDArray, but got %s"%type(x)
params = {i: j.var() for i, j in self._reg_params.items()}
with self.name_scope():
return self.hybrid_forward(symbol, x, *args, **params)
Expand Down Expand Up @@ -1173,7 +1273,7 @@ def forward(self, x, *args):
assert in_fmt == self._in_format, "Invalid input format"
ret = copy.copy(self._cached_graph[1])
ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
return _regroup(list(ret), self._out_format)[0]
return _regroup(list(ret), self._out_format)

def _clear_cached_op(self):
tmp = self._cached_graph
Expand Down
88 changes: 88 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,94 @@ def test_sparse_hybrid_block():
# an exception is expected when forwarding a HybridBlock w/ sparse param
y = net(x)

@with_seed()
def test_hybrid_block_none_args():
    """Tests that HybridBlock accepts None for any of its inputs, both in
    imperative mode and after hybridize(), and that mismatched None/non-None
    usage against a cached graph raises ValueError."""
    # Requires both positional arguments; returns whichever is non-None
    # (or the sum when both are given).
    class Foo(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b):
            if a is None and b is not None:
                return b
            elif b is None and a is not None:
                return a
            elif a is not None and b is not None:
                return a + b
            else:
                raise NotImplementedError

    # Same contract as Foo, but b defaults to None so it may be omitted.
    class FooDefault(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b=None):
            if a is None and b is not None:
                return b
            elif b is None and a is not None:
                return a
            elif a is not None and b is not None:
                return a + b
            else:
                raise NotImplementedError


    # Chains three Foo children so None must flow through nested HybridBlocks.
    class FooNested(gluon.HybridBlock):
        def __init__(self, prefix=None, params=None):
            super(FooNested, self).__init__(prefix=prefix, params=params)
            self.f1 = Foo(prefix='foo1')
            self.f2 = Foo(prefix='foo2')
            self.f3 = Foo(prefix='foo3')

        def hybrid_forward(self, F, a, b):
            data = self.f1(a, b)
            data = self.f2(a, data)
            data = self.f3(data, b)
            return data

    # Hybridized and non-hybridized versions must agree for every
    # None/non-None combination of the two inputs.
    for arg_inputs in [(None, mx.nd.ones((10,))),
                       (mx.nd.ones((10,)), mx.nd.ones((10,))),
                       (mx.nd.ones((10,)), None)]:
        foo1 = FooNested(prefix='foo_nested_hybridized')
        foo1.hybridize()
        foo2 = FooNested(prefix='foo_nested_nohybrid')
        for _ in range(2):  # Loop for 2 times to trigger forwarding of the cached version
            out1 = foo1(*arg_inputs)
            out2 = foo2(*arg_inputs)
            if isinstance(out1, tuple):
                for lhs, rhs in zip(out1, out2):
                    assert_almost_equal(lhs.asnumpy(), rhs.asnumpy())
            else:
                assert_almost_equal(out1.asnumpy(), out2.asnumpy())
    # All-None inputs are rejected in both modes.
    for do_hybridize in [True, False]:
        foo = FooNested()
        if do_hybridize:
            foo.hybridize()
        assert_raises(ValueError, foo, None, None)

    # Make sure the ValueError is correctly raised
    foo = FooNested()
    foo.hybridize()
    foo(None, mx.nd.ones((10,)))  # Pass for the first time to initialize the cached op
    # The cached graph expects a=None; passing a real NDArray must raise.
    assert_raises(ValueError, lambda: foo(mx.nd.ones((10,)), mx.nd.ones((10,))))
    # Mixing NDArray and Symbol inputs must raise, in either order.
    foo = FooNested()
    assert_raises(ValueError, lambda: foo(mx.nd.ones((10,)), mx.sym.var('a')))
    foo = FooNested()
    assert_raises(ValueError, lambda: foo(mx.sym.var('a'), mx.nd.ones((10,))))

    # Test the case of the default values
    foo1 = FooDefault()
    foo1.hybridize()
    foo2 = FooDefault()
    out1 = foo1(mx.nd.ones((10,)))
    out2 = foo2(mx.nd.ones((10,)))
    out3 = foo1(mx.nd.ones((10,)), None)
    out4 = foo2(mx.nd.ones((10,)), None)
    assert_almost_equal(out1.asnumpy(), out2.asnumpy())
    assert_almost_equal(out1.asnumpy(), out3.asnumpy())
    assert_almost_equal(out1.asnumpy(), out4.asnumpy())
    # Omitting b and passing b=None explicitly must hit the same cached graph;
    # a real second argument then conflicts with the cached None slot.
    foo1 = FooDefault()
    foo1.hybridize()
    out1 = foo1(mx.nd.ones((10,)), None)
    out2 = foo1(mx.nd.ones((10,)))
    assert_almost_equal(out1.asnumpy(), out2.asnumpy())
    assert_raises(ValueError, lambda: foo1(mx.nd.ones((10,)), mx.nd.ones((10,))))



@with_seed()
def check_layer_forward(layer, dshape):
print("checking layer {}\nshape: {}.".format(layer, dshape))
Expand Down