From 832dd1a5774c8abdc60108395fdfc308e4c3ea91 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Tue, 2 Jun 2020 00:17:41 -0700
Subject: [PATCH] Add deleting of args/aux to Partition API

Signed-off-by: Serge Panev
---
 python/mxnet/gluon/block.py   | 39 ++++++++++++++++++-----------------
 python/mxnet/symbol/symbol.py | 31 ++++++++++++++++++++++++++--
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2fda08067e0f..1f8fd0f4ec39 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1040,29 +1040,16 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-            self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
         if self._backend:
             ctx = args[0].context
@@ -1075,6 +1062,20 @@ def _build_cache(self, *args):
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
             #update cached graph with partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            if name in data_names:
+                data_indices.append(i)
+                self._cached_op_args.append((True, data_names[name]))
+            else:
+                param_indices.append(i)
+                self._cached_op_args.append((False, params[name]))
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+            self._flags
 
         self._cached_op = ndarray.CachedOp(out, flags)
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 39b8799ce155..5ffa9ac9d010 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,35 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
             raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
                                'Provide a dictionary to the aux argument to optimize_for')
 
-        # return modified symbol
-        return Symbol(out)
+        new_sym = Symbol(out)
+
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set([item for item in arg_names
+                                 if item not in set(new_arg_names)])
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('optimize_for deleted some arguments.\n' +
+                              'Provide a dictionary to the args argument to optimize_for')
+
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set([item for item in aux_names
+                                 if item not in set(new_aux_names)])
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('optimize_for deleted some aux states.\n' +
+                              'Provide a dictionary to the aux argument to optimize_for')
+
+        return new_sym
 
     # pylint: disable=too-many-locals
     def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,