
Add deleting of args/aux to Partition API
Signed-off-by: Serge Panev <[email protected]>
Kh4L committed Jun 29, 2020
1 parent b12abbf commit 832dd1a
Showing 2 changed files with 49 additions and 21 deletions.
39 changes: 20 additions & 19 deletions python/mxnet/gluon/block.py
@@ -1040,29 +1040,16 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
         if self._backend:
             ctx = args[0].context
@@ -1075,6 +1062,20 @@ def _build_cache(self, *args):
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
             #update cached graph with partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            if name in data_names:
+                data_indices.append(i)
+                self._cached_op_args.append((True, data_names[name]))
+            else:
+                param_indices.append(i)
+                self._cached_op_args.append((False, params[name]))
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
 
 
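The net effect of the two hunks above is that the data/parameter index flags handed to CachedOp are now computed from the partitioned graph's inputs (out.list_inputs() after optimize_for) rather than from the original graph, so inputs that the backend deleted no longer receive an index. A minimal, self-contained sketch of that index-splitting step follows; the names below are toy stand-ins for the real graph inputs and parameters, not MXNet objects.

# Toy illustration of the index-splitting logic above (plain Python, not MXNet code):
# after partitioning, the graph's input list may be shorter, so data/param
# indices must be rebuilt from the partitioned inputs.
input_names = ['data0', 'fc_weight', 'fc_bias']   # what out.list_inputs() might return
data_names = {'data0': 'data0'}                   # inputs that are data, not parameters
params = {'fc_weight': 'w', 'fc_bias': 'b'}       # surviving parameters, keyed by name

data_indices, param_indices, cached_op_args = [], [], []
for i, name in enumerate(input_names):
    if name in data_names:
        data_indices.append(i)
        cached_op_args.append((True, data_names[name]))
    else:
        param_indices.append(i)
        cached_op_args.append((False, params[name]))

# flags that would be passed to CachedOp:
# [('data_indices', [0]), ('param_indices', [1, 2])]
print([('data_indices', data_indices), ('param_indices', param_indices)])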
31 changes: 29 additions & 2 deletions python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,35 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
                                   'Provide a dictionary to the aux argument to optimize_for')
 
-        # return modified symbol
-        return Symbol(out)
+        new_sym = Symbol(out)
+
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set([item for item in arg_names
+                                 if item not in set(new_arg_names)])
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('optimize_for deleted some argument. \n' +
+                              'Provide a dictionary to the arg argument to optimize_for')
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set([item for item in aux_names
+                                 if item not in set(new_aux_names)])
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('optimize_for deleted some aux argument. \n' +
+                              'Provide a dictionary to the aux argument to optimize_for')
+
+        return new_sym
 
     # pylint: disable=too-many-locals
     def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
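For callers, the change above means optimize_for now prunes entries for arguments and auxiliary states that the backend dropped from the partitioned symbol, keeping the supplied args/aux dictionaries consistent with what the returned symbol actually expects. A rough usage sketch, assuming a subgraph backend is registered under the hypothetical name 'myBackend'; the symbol and shapes are only illustrative.

import mxnet as mx

data = mx.sym.Variable('data')
sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc')

# argument NDArrays keyed by name, as optimize_for expects ('data' is an input, not an arg dict entry here)
args = {name: mx.nd.ones((10, 4)) if 'weight' in name else mx.nd.ones((10,))
        for name in sym.list_arguments() if name != 'data'}
aux = {}

# 'myBackend' is a hypothetical registered subgraph backend; after partitioning,
# any argument/aux the backend folded away is popped from `args`/`aux` so they
# match part_sym.list_arguments()/list_auxiliary_states().
part_sym = sym.optimize_for('myBackend', args=args, aux=aux, ctx=mx.cpu())
assert set(args.keys()) <= set(part_sym.list_arguments())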
