Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add deleting of args aux aux to Partition API
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Jun 1, 2020
1 parent 3efacd2 commit 32dec75
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
40 changes: 21 additions & 19 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,29 +1006,16 @@ def _build_cache(self, *args):
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags

args, _ = _flatten(args, "input")
try:
for is_arg, i in self._cached_op_args:
if not is_arg:
i.data()
for name in input_names:
if name in params:
params[name].data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for is_arg, i in self._cached_op_args:
if not is_arg:
i._finish_deferred_init()
for name in input_names:
if name in params:
params[name]._finish_deferred_init()

if self._backend:
ctx = args[0].context
Expand All @@ -1037,10 +1024,25 @@ def _build_cache(self, *args):
for name in out.list_arguments()]
aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()]

# Partition the graph.
out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
#update cached graph with partitioned graph
self._cached_graph = data, out

input_names = out.list_inputs()
data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)


Expand Down
46 changes: 44 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,8 @@ def _gen_atomic_symbol(self):


# pylint: disable=too-many-locals
def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
def optimize_for(self, backend, args=None, aux=None, ctx=None,
deleted_args=None, deleted_aux=None, **kwargs):
"""Partitions current symbol and optimizes it for a given backend,
returns new partitioned symbol.
Expand All @@ -1473,6 +1474,14 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
ctx : Context, optional
Device context, used to infer stypes
deleted_args : list, optional
List receiving names of deleted args.
To be provided only if the backend can delete arguments.
deleted_aux : list, optional
List receiving names of deleted aux.
To be provided only if the backend can delete auxiliary states.
kwargs : optional arguments
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
Expand Down Expand Up @@ -1554,7 +1563,40 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
else:
aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
return Symbol(out)

new_sym = Symbol(out)

new_arg_names = new_sym.list_arguments()
deleted_arg_names = [item for item in arg_names
if item not in set(new_arg_names)]

new_aux_names = new_sym.list_auxiliary_states()
deleted_aux_names = [item for item in aux_names
if item not in set(new_aux_names)]

if isinstance(args, dict):
for a_n in deleted_arg_names:
if a_n in args:
args.pop(a_n)

if isinstance(aux, dict):
for a_n in deleted_aux_names:
if a_n in aux:
aux.pop(a_n)

if deleted_args is not None:
if isinstance(deleted_args, list):
deleted_args.extend(deleted_arg_names)
else:
raise ValueError('deleted_args has to be a list.')

if deleted_aux is not None:
if isinstance(deleted_aux, list):
deleted_aux.extend(deleted_aux_names)
else:
raise ValueError('deleted_args has to be a list.')

return new_sym


# pylint: disable=too-many-locals
Expand Down

0 comments on commit 32dec75

Please sign in to comment.