Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add deleting of args aux aux to Partition API
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Jun 2, 2020
1 parent 3efacd2 commit 5c57d6b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
40 changes: 21 additions & 19 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,29 +1006,16 @@ def _build_cache(self, *args):
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags

args, _ = _flatten(args, "input")
try:
for is_arg, i in self._cached_op_args:
if not is_arg:
i.data()
for name in input_names:
if name in params:
params[name].data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for is_arg, i in self._cached_op_args:
if not is_arg:
i._finish_deferred_init()
for name in input_names:
if name in params:
params[name]._finish_deferred_init()

if self._backend:
ctx = args[0].context
Expand All @@ -1037,10 +1024,25 @@ def _build_cache(self, *args):
for name in out.list_arguments()]
aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()]

# Partition the graph.
out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
#update cached graph with partitioned graph
self._cached_graph = data, out

input_names = out.list_inputs()
data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)


Expand Down
32 changes: 31 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,37 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
else:
aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
return Symbol(out)

new_sym = Symbol(out)

if args is not None:
new_arg_names = new_sym.list_arguments()
deleted_arg_names = set([item for item in arg_names
if item not in set(new_arg_names)])
if isinstance(args, dict):
for a_n in deleted_arg_names:
if a_n in args:
args.pop(a_n)
elif isinstance(args, list):
indices_to_delete = [i for i, name in enumerate(arg_names) if name in deleted_arg_names]
for idx in reversed(indices_to_delete):
args.pop(idx)

if aux is not None:
new_aux_names = new_sym.list_auxiliary_states()
deleted_aux_names = set([item for item in aux_names
if item not in set(new_aux_names)])

if isinstance(aux, dict):
for a_n in deleted_aux_names:
if a_n in aux:
aux.pop(a_n)
elif isinstance(args, list):
indices_to_delete = [i for i, name in enumerate(aux_names) if name in deleted_aux_names]
for idx in reversed(indices_to_delete):
aux.pop(idx)

return new_sym


# pylint: disable=too-many-locals
Expand Down

0 comments on commit 5c57d6b

Please sign in to comment.