From 5b9b033ea51c12962e14e039c44f959046c43633 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Mon, 22 Jun 2020 15:39:34 -0700 Subject: [PATCH] Delete args from Block.params Signed-off-by: Serge Panev --- python/mxnet/gluon/block.py | 12 ++++++++++++ python/mxnet/symbol/symbol.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 8a61ba572e90..45e14977e0c4 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1026,6 +1026,18 @@ def _build_cache(self, *args): for name in out.list_auxiliary_states()} # Partition the graph. out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts) + # BFS to delete the delete args/aux from the block Params and its children's Params + input_names = out.list_inputs() + queue = [self] + while len(queue) > 0: + curr_block = queue.pop(0) + curr_params_names = list(curr_block.params._params.keys()) + for k in curr_params_names: + if k not in input_names: + curr_block.params._params.pop(k) + + queue.extend(curr_block._children.values()) + #update cached graph with partitioned graph self._cached_graph = data, out diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index ae20a0dcc91b..ec14761d75c0 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1550,7 +1550,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): arg_names = self.list_arguments() new_arg_names = new_sym.list_arguments() deleted_arg_names = set([item for item in arg_names - if item not in set(new_arg_names)]) + if item not in set(new_arg_names)]) if len(deleted_arg_names) > 0: if args is not None: @@ -1563,7 +1563,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): aux_names = self.list_auxiliary_states() new_aux_names = new_sym.list_auxiliary_states() deleted_aux_names = set([item for item in aux_names - if item not in set(new_aux_names)]) + if item not in set(new_aux_names)]) if len(deleted_aux_names) > 0: if aux is not None: for a_n in deleted_aux_names: