Skip to content

Commit

Permalink
[Gluon] Support None argument in HybridBlock (apache#16280)
Browse files Browse the repository at this point in the history
* support none in hybridblock argument

support None as arguments in HybridBlock

Update block.py

fix

fix

Update test_gluon.py

Update test_gluon.py

* fix bug

* fix bug

* test case of default values
  • Loading branch information
sxjscience authored and sojiadeshina committed Sep 30, 2019
1 parent bc01a86 commit 66b06d6
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 43 deletions.
186 changes: 143 additions & 43 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .. util import is_np_array, np_shape, np_array



class _BlockScope(object):
"""Scope for collecting child `Block` s."""
_current = threading.local()
Expand Down Expand Up @@ -93,16 +94,35 @@ def __exit__(self, ptype, value, trace):


def _flatten(args, inout_str):
"""Parse the arguments into a flattened list + an additional format array.
The format array stores the structure of the original arguments to help reconstruct the inputs.
Parameters
----------
args : NDArray, Symbol, or (nested) list of Symbol or NDArray
We allow None inside the args.
inout_str : str
The name of the HybridBlock
Returns
-------
flat : list of Symbol or NDArray
The flatten version of the input args.
fmts : (nested) list of ints
Stores the format information of the original structured args.
"""
if isinstance(args, NDArray):
return [args], int(0)
if isinstance(args, Symbol):
length = len(args.list_outputs())
length = length if length > 1 else 0
return [args], int(length)
if args is None:
return [None], int(-1)

assert isinstance(args, (list, tuple)), \
"HybridBlock %s must be (nested) list of Symbol or NDArray, " \
"but got %s of type %s"%(inout_str, str(args), str(type(args)))
"HybridBlock {} must be (nested) list of Symbol or NDArray, " \
"but got {} of type {}".format(inout_str, str(args), str(type(args)))
flat = []
fmts = []
for i in args:
Expand All @@ -113,19 +133,46 @@ def _flatten(args, inout_str):


def _regroup(args, fmt):
if isinstance(fmt, int):
if fmt == 0:
return args[0], args[1:]
return args[:fmt], args[fmt:]
"""Reconstruct the structured arguments based on the flattened version.
assert isinstance(args, (list, tuple)), \
"HybridBlock output must be (nested) list of Symbol or NDArray, " \
"but got %s of type %s"%(str(args), str(type(args)))
ret = []
for i in fmt:
res, args = _regroup(args, i)
ret.append(res)
return ret, args
Parameters
----------
args : NDArray, Symbol, or (nested) list of Symbol or NDArray
We allow None inside the args.
fmt : (nested) list of ints
Stores the format information of the original structured args.
Returns
-------
ret : NDArray, Symbol, or (nested) list of Symbol or NDArray
"""
def _merger(args, fmt):
"""Recursive call to merge the arguments"""
if isinstance(fmt, int):
if fmt < -1:
raise ValueError("Unsupported encoded format {}.".format(fmt))
if fmt == 0:
return args[0], args[1:]
if fmt == -1:
if args[0] is not None:
raise ValueError('We do not support passing types that are not None'
' when the initial HybridBlock has received NoneType and'
' has been hybridized.'
' Received arg = {}, fmt = {}.'.format(args[0], fmt))
return None, args[1:]
else:
return args[:fmt], args[fmt:]

assert isinstance(args, (list, tuple)), \
"HybridBlock output must be (nested) list of Symbol or NDArray, " \
"but got {} of type {}".format(args, type(args))
ret = []
for i in fmt:
res, args = _merger(args, i)
ret.append(res)
return ret, args
return _merger(args, fmt)[0]


class Block(object):
Expand Down Expand Up @@ -778,37 +825,48 @@ def __setattr__(self, name, value):

def _get_graph(self, *args):
    """Build the symbolic graph for this block on first use and cache it.

    Parameters
    ----------
    *args : NDArray, Symbol, None, or (nested) list of them
        Example inputs. They are flattened with `_flatten`; every entry
        that is not None gets a placeholder symbol variable, while None
        entries are passed through to `hybrid_forward` as None.

    Returns
    -------
    tuple
        ``(symbol_inputs, out)`` where `symbol_inputs` is the list of
        placeholder variables (None entries excluded) and `out` is the
        `symbol.Group` built from the flattened outputs of `hybrid_forward`.

    Raises
    ------
    ValueError
        If every flattened argument is None.
    """
    if not self._cached_graph:
        flatten_args, self._in_format = _flatten(args, "input")
        real_arg_num = sum(ele is not None for ele in flatten_args)
        if real_arg_num == 0:
            raise ValueError('All args are None and we do not support such a case.'
                             ' Received args={}'.format(args))

        flatten_inputs = []  # one slot per flattened arg, None kept in place
        symbol_inputs = []   # only the real placeholder variables
        cnt = 0
        for arg in flatten_args:
            if arg is None:
                flatten_inputs.append(None)
                continue
            # A single real input is named 'data'; several are numbered by
            # their position among the non-None inputs.
            if real_arg_num > 1:
                arg_sym = symbol.var('data{}'.format(cnt))
            else:
                arg_sym = symbol.var('data')
            if isinstance(arg, _mx_np.ndarray):
                arg_sym = arg_sym.as_np_ndarray()
            cnt += 1
            flatten_inputs.append(arg_sym)
            symbol_inputs.append(arg_sym)

        grouped_inputs = _regroup(flatten_inputs, self._in_format)
        params = {i: j.var() for i, j in self._reg_params.items()}
        with self.name_scope():
            out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
        out, self._out_format = _flatten(out, "output")

        self._cached_graph = symbol_inputs, symbol.Group(out, _check_same_symbol_type(out))

    return self._cached_graph

def _build_cache(self, *args):
data, out = self._get_graph(*args)
data_names = {data.name : i for i, data in enumerate(data)}
data_names = {data.name: i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(input_names)
for name in expected_names:
assert name in param_names or name in data_names, \
"Unknown input to HybridBlock: %s"%name
"Unknown input to HybridBlock: %s" %name

used_data_names = [i for i in data_names if i in expected_names]
if len(used_data_names) != len(data_names):
Expand Down Expand Up @@ -856,23 +914,40 @@ def _call_cached_op(self, *args):
" and may not work correctly")

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
if fmt != self._in_format:
# Do not raise in the case that the fmt or stored_fmt ends with None and
# We are relying on the default values.
if len(self._in_format) > len(fmt):
valid = all([self._in_format[i] == -1
for i in range(len(fmt), len(self._in_format))])
valid = valid and (fmt == self._in_format[:len(fmt)])
elif len(self._in_format) < len(fmt):
valid = all([fmt[i] == -1
for i in range(len(self._in_format), len(fmt))])
valid = valid and (fmt[:len(self._in_format)] == self._in_format)
else:
valid = False
if not valid:
raise ValueError("The argument structure of HybridBlock does not match"
" the cached version. Stored format = {}, input format = {}"
.format(fmt, self._in_format))
args_without_none = [ele for ele in args if ele is not None]
try:
cargs = [args[i] if is_arg else i.data()
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args[i])
cargs.append(args_without_none[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
return _regroup(out, self._out_format)

def _clear_cached_op(self):
self._cached_graph = ()
Expand Down Expand Up @@ -906,9 +981,10 @@ def _infer_attrs(self, infer_fn, attr, *args):
"""Generic infer attributes."""
inputs, out = self._get_graph(*args)
args, _ = _flatten(args, "input")
args_without_none = [ele for ele in args if ele is not None]
with warnings.catch_warnings(record=True) as w:
arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
**{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
**{i.name: getattr(j, attr) for i, j in zip(inputs, args_without_none)})
if arg_attrs is None:
raise ValueError(w[0].message)
sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
Expand Down Expand Up @@ -978,24 +1054,48 @@ def register_op_hook(self, callback, monitor_all=False):
def forward(self, x, *args):
    """Defines the forward computation. Arguments can be either
    :py:class:`NDArray` or :py:class:`Symbol`.

    None is also accepted inside the (nested) arguments, but at least one
    NDArray or Symbol must be present and the two kinds must not be mixed.

    Raises
    ------
    ValueError
        If NDArrays and Symbols are mixed, or if no NDArray or Symbol is
        found among the flattened inputs.
    """
    flatten_args = _flatten([x] + list(args), 'inputs')[0]

    def _mixed_types_error():
        """Build the error raised when both NDArray and Symbol inputs are seen."""
        return ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
                          ' types for the input.\n'
                          'Received types are: {}.'
                          .format([type(ele) for ele in flatten_args]))

    # is_ndarray stays None until the first NDArray/Symbol is encountered.
    is_ndarray = None
    ctx = None
    for ele in flatten_args:
        if isinstance(ele, NDArray):
            if is_ndarray is False:
                raise _mixed_types_error()
            is_ndarray = True
            ctx = ele.context
        elif isinstance(ele, Symbol):
            if is_ndarray:
                raise _mixed_types_error()
            is_ndarray = False
        else:
            assert ele is None, 'Only support None, NDArray and Symbol as the input'
    if is_ndarray is None:
        raise ValueError('There must be at least one NDArray or Symbol in the input,'
                         ' received types: {}.'
                         .format([type(ele) for ele in flatten_args]))

    if is_ndarray:
        with ctx:
            if self._active:
                return self._call_cached_op(x, *args)

            try:
                params = {k: v.data(ctx) for k, v in self._reg_params.items()}
            except DeferredInitializationError:
                # Infer shapes from the inputs, finish the deferred
                # initialization, then fetch the parameter data again.
                self._deferred_infer_shape(x, *args)
                for v in self.params.values():
                    v._finish_deferred_init()
                params = {k: v.data(ctx) for k, v in self._reg_params.items()}

            return self.hybrid_forward(ndarray, x, *args, **params)

    params = {k: v.var() for k, v in self._reg_params.items()}
    with self.name_scope():
        return self.hybrid_forward(symbol, x, *args, **params)
Expand Down Expand Up @@ -1173,7 +1273,7 @@ def forward(self, x, *args):
assert in_fmt == self._in_format, "Invalid input format"
ret = copy.copy(self._cached_graph[1])
ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
return _regroup(list(ret), self._out_format)[0]
return _regroup(list(ret), self._out_format)

def _clear_cached_op(self):
tmp = self._cached_graph
Expand Down
88 changes: 88 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,94 @@ def test_sparse_hybrid_block():
# an exception is expected when forwarding a HybridBlock w/ sparse param
y = net(x)

@with_seed()
def test_hybrid_block_none_args():
    """Check that HybridBlock accepts None arguments, hybridized or not."""

    class Foo(gluon.HybridBlock):
        # Returns whichever input is present, or their sum when both are.
        def hybrid_forward(self, F, a, b):
            if a is None and b is None:
                raise NotImplementedError
            if a is None:
                return b
            if b is None:
                return a
            return a + b

    class FooDefault(gluon.HybridBlock):
        # Same contract as Foo, but `b` may be omitted by the caller.
        def hybrid_forward(self, F, a, b=None):
            if a is None and b is None:
                raise NotImplementedError
            if a is None:
                return b
            if b is None:
                return a
            return a + b

    class FooNested(gluon.HybridBlock):
        def __init__(self, prefix=None, params=None):
            super(FooNested, self).__init__(prefix=prefix, params=params)
            self.f1 = Foo(prefix='foo1')
            self.f2 = Foo(prefix='foo2')
            self.f3 = Foo(prefix='foo3')

        def hybrid_forward(self, F, a, b):
            data = self.f1(a, b)
            data = self.f2(a, data)
            data = self.f3(data, b)
            return data

    def _ones():
        return mx.nd.ones((10,))

    # The hybridized and the imperative block must agree on every None layout.
    input_cases = [
        (None, _ones()),
        (_ones(), _ones()),
        (_ones(), None),
    ]
    for arg_inputs in input_cases:
        hybrid_net = FooNested(prefix='foo_nested_hybridized')
        hybrid_net.hybridize()
        plain_net = FooNested(prefix='foo_nested_nohybrid')
        # Two passes, so the second one goes through the cached version.
        for _ in range(2):
            hybrid_out = hybrid_net(*arg_inputs)
            plain_out = plain_net(*arg_inputs)
            if not isinstance(hybrid_out, tuple):
                assert_almost_equal(hybrid_out.asnumpy(), plain_out.asnumpy())
                continue
            for lhs, rhs in zip(hybrid_out, plain_out):
                assert_almost_equal(lhs.asnumpy(), rhs.asnumpy())

    # Passing only None is rejected in both modes.
    for do_hybridize in (True, False):
        net = FooNested()
        if do_hybridize:
            net.hybridize()
        assert_raises(ValueError, net, None, None)

    # Make sure the ValueError is correctly raised
    net = FooNested()
    net.hybridize()
    net(None, _ones())  # Pass for the first time to initialize the cached op
    assert_raises(ValueError, lambda: net(_ones(), _ones()))
    net = FooNested()
    assert_raises(ValueError, lambda: net(_ones(), mx.sym.var('a')))
    net = FooNested()
    assert_raises(ValueError, lambda: net(mx.sym.var('a'), _ones()))

    # Test the case of the default values
    hybrid_default = FooDefault()
    hybrid_default.hybridize()
    plain_default = FooDefault()
    out_hybrid_omitted = hybrid_default(_ones())
    out_plain_omitted = plain_default(_ones())
    out_hybrid_none = hybrid_default(_ones(), None)
    out_plain_none = plain_default(_ones(), None)
    assert_almost_equal(out_hybrid_omitted.asnumpy(), out_plain_omitted.asnumpy())
    assert_almost_equal(out_hybrid_omitted.asnumpy(), out_hybrid_none.asnumpy())
    assert_almost_equal(out_hybrid_omitted.asnumpy(), out_plain_none.asnumpy())

    # Same check with the explicit None call made first.
    hybrid_default = FooDefault()
    hybrid_default.hybridize()
    out_explicit = hybrid_default(_ones(), None)
    out_omitted = hybrid_default(_ones())
    assert_almost_equal(out_explicit.asnumpy(), out_omitted.asnumpy())
    assert_raises(ValueError, lambda: hybrid_default(_ones(), _ones()))



@with_seed()
def check_layer_forward(layer, dshape):
print("checking layer {}\nshape: {}.".format(layer, dshape))
Expand Down

0 comments on commit 66b06d6

Please sign in to comment.