From 11a7a5c62c8532b510079e770890faa2dfd4a8b6 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Thu, 2 Jul 2020 21:54:20 -0700
Subject: [PATCH] Fix to use arg/aux dict when optimize_for is called in
 HybridBlock

Signed-off-by: Serge Panev
---
 python/mxnet/gluon/block.py | 74 +++++++++++++++++++++++++------------
 1 file changed, 50 insertions(+), 24 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 89406762a64e..fe44a848d582 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1051,27 +1051,16 @@ def _build_cache(self, *args):
             if name in params:
                 params[name]._finish_deferred_init()
 
+        arg_dict, aux_dict = dict(), dict()
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_arguments()}
-            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_auxiliary_states()}
+            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                        for name in out.list_arguments()})
+            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                        for name in out.list_auxiliary_states()})
             # Partition the graph.
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
-            # BFS to delete the delete args/aux from the block Params and its children's Params
-            input_names = out.list_inputs()
-            queue = [self]
-            while len(queue) > 0:
-                curr_block = queue.pop(0)
-                curr_params = curr_block.params if isinstance(curr_block.params, dict) else curr_block.params._params
-                curr_params_names = list(curr_params.keys())
-                for k in curr_params_names:
-                    if k not in input_names:
-                        curr_params.pop(k)
-
-                queue.extend(curr_block._children.values())
             #update cached graph with partitioned graph
             self._cached_graph = data, out
 
@@ -1081,12 +1070,30 @@ def _build_cache(self, *args):
         data_indices = []
         param_indices = []
         self._cached_op_args = []
         for i, name in enumerate(input_names):
+            pair = None
             if name in data_names:
                 data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
+                pair = (True, data_names[name])
             else:
                 param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
+                if name in params:
+                    param = params[name]
+                else:
+                    assert self._backend, "Parameter " + name + " is missing from block params"
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('Expected inputs missing from arg and aux after partitioning. '
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
         flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
             self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
@@ -1335,12 +1342,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
         arg_dict = {}
-        for param in self.collect_params().values():
-            if param.name in arg_names:
-                arg_dict['arg:%s'%param.name] = param._reduce()
-            else:
-                assert param.name in aux_names
-                arg_dict['aux:%s'%param.name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         params_filename = '%s-%04d.params'%(path, epoch)
         save_fn(params_filename, arg_dict)
@@ -1451,6 +1460,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
+
+        Parameters
+        ----------
+        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # reset parameters created by the partitioning backend
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
+            p.reset_ctx(ctx)
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
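
Usage illustration (not part of the patch): after this change, a partitioned
HybridBlock can be exported and moved across contexts even when the backend
creates new parameters during partitioning. A minimal sketch, assuming MXNet
1.7+ and a backend registered with the Partition API under the illustrative
name 'myBackend', plus a GPU for the final reset_ctx call:

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    net.add(nn.Dense(10))
    net.initialize(ctx=mx.cpu())

    x = mx.nd.ones((1, 4))
    # optimize_for() hybridizes the block and partitions its graph for the
    # backend; _build_cache() now passes arg_dict/aux_dict down to
    # Symbol.optimize_for(), and wraps any parameter the backend adds in a
    # new Parameter stored in _cached_op_args.
    net.optimize_for(x, backend='myBackend')

    # export() now iterates _cached_op_args, so backend-created parameters
    # are written to the .params file alongside the original ones.
    net.export('partitioned_net', epoch=0)

    # reset_ctx() moves both regular and backend-created parameters.
    net.reset_ctx(mx.gpu(0))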