Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix to use arg/aux dict when optimize_for is called in block
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Jul 2, 2020
1 parent 2b6bab7 commit ab422c4
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ def __init__(self):
super(HybridBlock, self).__init__()
self._cached_graph = ()
self._cached_op = None
self._cached_op_args = None
self._out_format = None
self._in_format = None
self._called_infer_shape_already = False
Expand Down Expand Up @@ -1051,27 +1052,16 @@ def _build_cache(self, *args):
if name in params:
params[name]._finish_deferred_init()

arg_dict, aux_dict = dict(), dict()
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()}
aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()}
arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()})
aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()})
# Partition the graph.
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
# BFS to delete the args/aux from the block Params and its children's Params
input_names = out.list_inputs()
queue = [self]
while len(queue) > 0:
curr_block = queue.pop(0)
curr_params = curr_block.params if isinstance(curr_block.params, dict) else curr_block.params._params
curr_params_names = list(curr_params.keys())
for k in curr_params_names:
if k not in input_names:
curr_params.pop(k)

queue.extend(curr_block._children.values())

# Update cached graph with partitioned graph
self._cached_graph = data, out
Expand All @@ -1081,12 +1071,24 @@ def _build_cache(self, *args):
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
pair = None
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
pair = (True, data_names[name])
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
if not self._backend:
pair = (False, params[name])
else:
if name in arg_dict:
pair = (False, arg_dict[name])
elif name in aux_dict:
pair = (False, aux_dict[name])
else:
raise RuntimeError('Expected inputs missing from arg and aux after partitioning. '
'Please check the backend.')
self._cached_op_args.append(pair + (name,))

flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)
Expand Down Expand Up @@ -1134,7 +1136,7 @@ def _call_cached_op(self, *args):

args_without_none = [ele for ele in args if ele is not None]
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
for is_arg, i, _ in self._cached_op_args]
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
Expand Down Expand Up @@ -1335,12 +1337,13 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
for param in self.collect_params().values():
if param.name in arg_names:
arg_dict['arg:%s'%param.name] = param._reduce()
else:
assert param.name in aux_names
arg_dict['aux:%s'%param.name] = param._reduce()
for is_arg, param, name in self._cached_op_args:
if not is_arg:
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
else:
assert name in aux_names
arg_dict['aux:{}'.format(name)] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
Expand Down

0 comments on commit ab422c4

Please sign in to comment.