This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Gluon] [Fix] Fix HybridBlock when hybridize is not called #16465

Merged 1 commit on Oct 20, 2019
110 changes: 74 additions & 36 deletions python/mxnet/gluon/block.py
@@ -93,6 +93,53 @@ def __exit__(self, ptype, value, trace):
        _BlockScope._current.value = self._old_scope


def _gather_type_ctx_info(args):
    """Analyze the elements inside the nested args object and find:

    - whether there exists an NDArray
    - whether there exists a Symbol
    - all contexts appearing in args

    Parameters
    ----------
    args : list or NDArray or Symbol
        Can be a nested structure.

    Returns
    -------
    has_symbol : bool
        Whether the elements in args contain Symbols
    has_ndarray : bool
        Whether the elements in args contain NDArrays
    ctx_set : set of mxnet.context.Context
        All contexts of the inner NDArrays in args. Can be empty if there is no
        NDArray inside args.
    first_ctx : mxnet.context.Context or None
        Context of the first NDArray encountered (kept for backward compatibility)
    """
    if isinstance(args, NDArray):
        return False, True, {args.context}, args.context
    elif isinstance(args, Symbol):
        return True, False, set(), None
    elif isinstance(args, (list, tuple)):
        has_symbol = False
        has_ndarray = False
        ctx_set = set()
        first_ctx = None
        for ele in args:
            ele_has_sym, ele_has_nd, ele_ctx_set, ele_first_ctx = \
                _gather_type_ctx_info(ele)
            has_symbol = has_symbol or ele_has_sym
            has_ndarray = has_ndarray or ele_has_nd
            if first_ctx is None and ele_first_ctx is not None:
                first_ctx = ele_first_ctx
            ctx_set = ctx_set | ele_ctx_set
            if has_symbol and has_ndarray:
                break
        return has_symbol, has_ndarray, ctx_set, first_ctx
    else:
        return False, False, set(), None
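
For illustration only (not part of the committed diff), a minimal sketch of how this new helper could be exercised on a nested input, assuming the patch is applied and mxnet is importable:

import mxnet as mx
from mxnet.gluon.block import _gather_type_ctx_info  # private helper added above

# A nested input mixing NDArrays and None, all on the default CPU context.
nested_args = [mx.nd.ones((2,)), (mx.nd.zeros((3,)), None)]
has_symbol, has_ndarray, ctx_set, first_ctx = _gather_type_ctx_info(nested_args)
# Expected: has_symbol is False, has_ndarray is True,
# ctx_set == {mx.cpu(0)} and first_ctx == mx.cpu(0).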


def _flatten(args, inout_str):
    """Parse the arguments into a flattened list + an additional format array.
    The format array stores the structure of the original arguments to help reconstruct the inputs.
@@ -120,9 +167,11 @@ def _flatten(args, inout_str):
    if args is None:
        return [None], int(-1)

    assert isinstance(args, (list, tuple)), \
        "HybridBlock {} must be (nested) list of Symbol or NDArray, " \
        "but got {} of type {}".format(inout_str, str(args), str(type(args)))
    if not isinstance(args, (list, tuple)):
        raise ValueError("When hybridized, the input of HybridBlock {}"
                         " must be (nested) list of Symbol"
                         " or NDArray, "
                         "but got {} of type {}".format(inout_str, str(args), str(type(args))))
    flat = []
    fmts = []
    for i in args:
@@ -164,9 +213,10 @@ def _merger(args, fmt):
            else:
                return args[:fmt], args[fmt:]

        assert isinstance(args, (list, tuple)), \
            "HybridBlock output must be (nested) list of Symbol or NDArray, " \
            "but got {} of type {}".format(args, type(args))
        if not isinstance(args, (list, tuple)):
            raise ValueError("When hybridized, the output of HybridBlock must be (nested)"
                             " list of Symbol or NDArray, "
                             "but got {} of type {}".format(args, type(args)))
        ret = []
        for i in fmt:
            res, args = _merger(args, i)
@@ -1054,38 +1104,26 @@ def register_op_hook(self, callback, monitor_all=False):
    def forward(self, x, *args):
        """Defines the forward computation. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`."""
        flatten_args = _flatten([x] + list(args), 'inputs')[0]
        is_ndarray = None
        ctx = None
        exist_sym_nd = False
        for ele in flatten_args:
            if isinstance(ele, NDArray):
                if is_ndarray is False:
                    raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
                                     ' types for the input.\n'
                                     'Received types are: {}.'
                                     .format([type(ele) for ele in flatten_args]))
                is_ndarray = True
                exist_sym_nd = True
                ctx = ele.context
Review comment thread (attached to the context-handling change above):

Contributor:
This line introduced in #16280 is not compatible with the previous context handling. Previously, x.context was always used as the default context. https://github.com/apache/incubator-mxnet/pull/16280/files#diff-29da832c2145752f3906a2f71c7b63baL982
Does it matter?

Member Author:
It is chosen like this because x can be None now.

Contributor:
Then it should be set to the first non-None argument, not the last?

Member Author:
Actually, all contexts are supposed to be the same. For example, we should not allow mixing cpu and gpu contexts. However, we currently allow it because we need to mix cpu, cpu_pinned, and cpu_shared.

Contributor:
Thus we should use the first non-None argument so as not to break backward compatibility? cpu, cpu_pinned, and cpu_shared are different contexts after all.

Member Author (@sxjscience, Oct 15, 2019):
@leezu I think using the first or last non-None argument does not matter much here. Our goal is to make sure that we finally pick a meaningful context for the parameters. In fact, the previous implementation did not check whether the contexts of the arguments are valid.

Contributor:
As the previous implementation hasn't enforced all contexts being equal, we shouldn't start picking a different array to determine the context. As you stated above, it is valid to use a mix of cpu, cpu_pinned, and cpu_shared contexts.
For example, after your change, cpu_pinned or cpu_shared may be picked as the default context instead of cpu if the user passed a cpu_pinned or cpu_shared array as the last argument. The extra overhead could cause a performance regression (all parameters will be made available under the default context).
No need to risk this given there is no advantage?

Member Author:
@leezu It is also possible that cpu_pinned was previously picked as the default context and, after the change, the correct cpu context is picked instead. My point is that we probably need to give special treatment to the cpu, cpu_pinned, and cpu_shared combination. What's your opinion?

Member Author:
@leezu I agree that the backward-compatibility issue is valid. Let me first make it backward-compatible. However, this does not fix the issue of the cpu, cpu_pinned, cpu_shared combination.

Contributor:
We should get rid of choosing one array and using its context as the default context. For parameters, users should get the array via self.weight.data(ctx). For the time being, I suggest not breaking the behaviour, to avoid unintended consequences.
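
To make the cpu/cpu_pinned point above concrete, here is a small illustrative sketch (not part of this diff; the variable names a and b are hypothetical, and it assumes a standard mxnet build where mx.cpu_pinned is available):

import mxnet as mx

# Two inputs whose contexts differ only in CPU flavour; mixing them is allowed.
a = mx.nd.ones((2,), ctx=mx.cpu(0))
b = mx.nd.ones((2,), ctx=mx.cpu_pinned(0))

# Under the backward-compatible choice discussed above, the context of the first
# NDArray (mx.cpu(0) here) would serve as the default context, instead of the
# context of whichever NDArray happens to come last (mx.cpu_pinned(0) here).
print(a.context, b.context)  # cpu(0) cpu_pinned(0)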

            elif isinstance(ele, Symbol):
                if is_ndarray:
                    raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
                                     ' types for the input.\n'
                                     'Received types are: {}.'
                                     .format([type(ele) for ele in flatten_args]))
                is_ndarray = False
                exist_sym_nd = True
            else:
                assert ele is None, 'Only support None, NDArray and Symbol as the input'
        if not exist_sym_nd:
            raise ValueError('There must at least one NDArray or Symbol in the input, received')

        if is_ndarray:
            with ctx:
                if self._active:
        has_symbol, has_ndarray, ctx_set, first_ctx = _gather_type_ctx_info([x] + list(args))
        if has_symbol and has_ndarray:
            raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
                             ' types for the input. Please check the type of the args.\n')
        if not has_symbol and not has_ndarray:
            raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
                             ' Please check the type of the args.\n')
        if has_ndarray:
            ctx = first_ctx
            if self._active:
                if len(ctx_set) > 1:
                    raise ValueError('Find multiple contexts in the input, '
                                     'After hybridized, the HybridBlock only supports one input '
                                     'context. You can print the ele.context in the '
                                     'input arguments to inspect their contexts. '
                                     'Find all contexts = {}'.format(ctx_set))
                with ctx:
                    return self._call_cached_op(x, *args)

            with ctx:
                try:
                    params = {k: v.data(ctx) for k, v in self._reg_params.items()}
                except DeferredInitializationError:
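
Before the test diffs below, a brief, non-authoritative usage sketch of the behaviour this change establishes (assumes the patch is applied; SumBlock is a made-up block for illustration, mirroring the FooHybrid test case):

import mxnet as mx
from mxnet import gluon

class SumBlock(gluon.HybridBlock):
    """Toy block: sums (possibly list-valued) pairs of inputs."""
    def hybrid_forward(self, F, a, b):
        if isinstance(a, (list, tuple)):
            a = sum(a)
        if isinstance(b, (list, tuple)):
            b = sum(b)
        return a + b

net = SumBlock()
# Without hybridize(), the HybridBlock behaves like a plain Block, so list and
# scalar arguments are accepted (the case this PR fixes).
out = net([mx.nd.ones((4,)), mx.nd.ones((4,))], 3)
print(out.asnumpy())  # array of 5s

net.hybridize()
# After hybridize(), scalars, mixed NDArray/Symbol inputs, and inputs spread
# over multiple contexts are rejected with ValueError, matching the checks
# added to forward() above.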
15 changes: 15 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
@@ -510,6 +510,21 @@ def test_bulking():
        .format(fully_bulked_time - fastest_half_bulked_time, times_str)


@with_seed()
def test_hybridblock_mix_ctx_raise():
    class FooHybrid(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b):
            if isinstance(a, (list, tuple)):
                a = sum(a)
            if isinstance(b, (list, tuple)):
                b = sum(b)
            return a + b
    foo_hybrid = FooHybrid()
    foo_hybrid.hybridize()
    assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.gpu()),
                                                 mx.nd.ones((10,), ctx=mx.cpu())))


if __name__ == '__main__':
    import nose
    nose.runmodule()
44 changes: 44 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -528,6 +528,50 @@ def hybrid_forward(self, F, a, b):
    assert_raises(ValueError, lambda: foo1(mx.nd.ones((10,)), mx.nd.ones((10,))))


@with_seed()
def test_hybrid_block_hybrid_no_hybrid():
    class FooHybrid(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b):
            if isinstance(a, (list, tuple)):
                a = sum(a)
            if isinstance(b, (list, tuple)):
                b = sum(b)
            return a + b

    class Foo(gluon.Block):
        def forward(self, a, b):
            if isinstance(a, (list, tuple)):
                a = sum(a)
            if isinstance(b, (list, tuple)):
                b = sum(b)
            return a + b
    # When hybridize is not called, HybridBlock acts the same as Block
    foo_hybrid = FooHybrid()
    foo = Foo()
    for a, b in [(mx.nd.ones((10,)), 1),
                 (mx.nd.ones((20,)), 2),
                 ([mx.nd.ones((10,)), mx.nd.ones((10,))],
                  [mx.nd.ones((10)), mx.nd.ones((10,)), mx.nd.ones((10,))]),
                 ([mx.nd.ones((10,)), mx.nd.ones((10,))], 3)]:
        hybrid_block_out = foo_hybrid(a, b)
        block_out = foo(a, b)
        assert_almost_equal(hybrid_block_out.asnumpy(), block_out.asnumpy())
    # When hybridize is called, we need to make sure that the model raises for the unsupported cases
    # 1. Scalar values in the input
    # 2. No mixing of sym/ndarray
    # 3. No mixing of cpu ndarray and gpu ndarray (Tested in gpu/test_gluon_gpu.py)
    # 4. Allow mixing of cpu_pinned and cpu
    foo_hybrid = FooHybrid()
    foo_hybrid.hybridize()
    assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,)), 1))
    foo_hybrid = FooHybrid()
    foo_hybrid.hybridize()
    assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,)), mx.sym.var('a')))
    foo_hybrid = FooHybrid()
    foo_hybrid.hybridize()
    assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.cpu(1)),
                                                 mx.nd.ones((10,), ctx=mx.cpu(2))))


@with_seed()
def check_layer_forward(layer, dshape):