Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix to use arg/auxdict when optimize_for is called in HybridBlock
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Jul 3, 2020
1 parent 2b6bab7 commit 0e9a920
Showing 1 changed file with 51 additions and 25 deletions.
76 changes: 51 additions & 25 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,27 +1051,16 @@ def _build_cache(self, *args):
if name in params:
params[name]._finish_deferred_init()

arg_dict, aux_dict = dict(), dict()
ctx = args[0].context
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()}
aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()}
arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()})
aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()})
# Partition the graph.
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
# BFS to delete the delete args/aux from the block Params and its children's Params
input_names = out.list_inputs()
queue = [self]
while len(queue) > 0:
curr_block = queue.pop(0)
curr_params = curr_block.params if isinstance(curr_block.params, dict) else curr_block.params._params
curr_params_names = list(curr_params.keys())
for k in curr_params_names:
if k not in input_names:
curr_params.pop(k)

queue.extend(curr_block._children.values())

#update cached graph with partitioned graph
self._cached_graph = data, out
Expand All @@ -1081,12 +1070,30 @@ def _build_cache(self, *args):
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
pair = None
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
pair = (True, data_names[name])
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
if name in params:
param = params[name]
else:
assert self._backend, "Parameter " + name + " is missing from block params"
if name in arg_dict or name:
param_data = arg_dict[name]
elif name in aux_dict:
param_data = aux_dict[name]
else:
raise RuntimeError('Expected inputs missing from arg and aux after partioning. '
'Please check the backend.')

param = Parameter(name)
param._load_init(param_data, ctx)
pair = (False, param)

self._cached_op_args.append(pair)

flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)
Expand Down Expand Up @@ -1335,12 +1342,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
for param in self.collect_params().values():
if param.name in arg_names:
arg_dict['arg:%s'%param.name] = param._reduce()
else:
assert param.name in aux_names
arg_dict['aux:%s'%param.name] = param._reduce()
for is_arg, param in self._cached_op_args:
if not is_arg:
name = param.name
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
else:
assert name in aux_names
arg_dict['aux:{}'.format(name)] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
Expand Down Expand Up @@ -1451,6 +1460,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
# pylint: disable= invalid-name
raise NotImplementedError

def reset_ctx(self, ctx):
"""Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
Parameters
----------
ctx : Context or list of Context, default :py:meth:`context.current_context()`.
Assign Parameter to given context. If ctx is a list of Context, a
copy will be made for each context.
"""
params = self.collect_params()
if self._cached_op:
for p in self._cached_op_args:
# resetting parameters creating by the partitioning backend
if p.name not in params:
p.reset_ctx(ctx)
for p in params.values():
p.reset_ctx(ctx)

class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
Expand Down

0 comments on commit 0e9a920

Please sign in to comment.