Skip to content

Commit

Permalink
Partitioning Gluon HybridBlocks (apache#15969)
Browse files Browse the repository at this point in the history
* stub for optimizing Gluon block

* Init commit for Gluon hybridblocks partition(sample test included)

* Added tests for Gluon and refactored tests

* call optimize_for in _build_cache

* Pass in 4 parameters for gluon optimize_for

* Fixed auxiliary state issue, args issue and added 2 tests.

* Fixed auxiliary state issue, args issue and added 2 tests.

* changed parameter check

* refactored param init since needed for partitioning

* fixed whitespace

* fixed flattened args

* fixed sanity & updated tests

* fixed whitespace

* added context support in tests

* Fix python2 errors

* clean code remove cargs

* Add hybridblock hybridize() description

Co-authored-by: guanxinq <[email protected]>
  • Loading branch information
2 people authored and Ubuntu committed Feb 19, 2020
1 parent 2bb4f81 commit 5daf529
Show file tree
Hide file tree
Showing 2 changed files with 508 additions and 361 deletions.
80 changes: 52 additions & 28 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,19 +656,7 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
self.collect_params().initialize(init, ctx, verbose, force_reinit)

def hybridize(self, active=True, **kwargs):
"""Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
non-hybrid children.
Parameters
----------
active : bool, default True
Whether to turn hybrid on or off.
static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
""" Please refer description of HybridBlock hybridize().
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
Expand Down Expand Up @@ -890,6 +878,8 @@ def __init__(self, prefix=None, params=None):
self._flags = []
self._callback = None
self._monitor_all = False
self._backend = None
self._backend_args = {}

def __setattr__(self, name, value):
"""Registers parameters."""
Expand Down Expand Up @@ -935,7 +925,6 @@ def _build_cache(self, *args):
data_names = {data.name: i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(input_names)
for name in expected_names:
Expand Down Expand Up @@ -967,6 +956,26 @@ def _build_cache(self, *args):
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags

args, _ = _flatten(args, "input")
try:
for is_arg, i in self._cached_op_args:
if not is_arg:
i.data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for is_arg, i in self._cached_op_args:
if not is_arg:
i._finish_deferred_init()

if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()]
# Partition the graph.
out = out.optimize_for(self._backend, arg_array, ctx, **self._backend_args)

self._cached_op = ndarray.CachedOp(out, flags)

def _deferred_infer_shape(self, *args):
Expand Down Expand Up @@ -1008,19 +1017,10 @@ def _call_cached_op(self, *args):
raise ValueError("The argument structure of HybridBlock does not match"
" the cached version. Stored format = {}, input format = {}"
.format(fmt, self._in_format))

args_without_none = [ele for ele in args if ele is not None]
try:
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args_without_none[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
Expand All @@ -1040,7 +1040,32 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, **kwargs):
def hybridize(self, active=True, backend=None, backend_args=None, **kwargs):
"""Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
non-hybrid children.
Parameters
----------
active : bool, default True
Whether to turn hybrid on or off.
backend : str
The name of the backend, as registered in `SubgraphBackendRegistry`, default None
backend_args : dict of arguments, optional
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
"""

self._backend = backend
if backend_args is not None:
assert isinstance(backend_args, dict), \
"HybridBlock hybridize requires backend_args to be a dictionary."
self._backend_args = backend_args

self._active = active
self._flags = list(kwargs.items())
self._clear_cached_op()
Expand Down Expand Up @@ -1160,7 +1185,6 @@ def forward(self, x, *args):
params = {k: v.data(ctx) for k, v in self._reg_params.items()}

return self.hybrid_forward(ndarray, x, *args, **params)

params = {i: j.var() for i, j in self._reg_params.items()}
with self.name_scope():
return self.hybrid_forward(symbol, x, *args, **params)
Expand Down
Loading

0 comments on commit 5daf529

Please sign in to comment.