Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Partition API adding and deleting new params to Block and Symbol (#18405
Browse files Browse the repository at this point in the history
)

* Add deleting of args aux aux to Partition API

Signed-off-by: Serge Panev <[email protected]>

* Delete args from Block.params

Signed-off-by: Serge Panev <[email protected]>

* Fix to use arg/auxdict when optimize_for is called in HybridBlock

Signed-off-by: Serge Panev <[email protected]>

* Address PR comments

Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L authored Jul 13, 2020
1 parent 19e373d commit 9c5b95a
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 31 deletions.
105 changes: 76 additions & 29 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,41 +1040,69 @@ def _build_cache(self, *args):
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags

args, _ = _flatten(args, "input")
try:
for is_arg, i in self._cached_op_args:
if not is_arg:
i.data()
for name in input_names:
if name in params:
params[name].data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for is_arg, i in self._cached_op_args:
if not is_arg:
i._finish_deferred_init()
for name in input_names:
if name in params:
params[name]._finish_deferred_init()

arg_dict, aux_dict = dict(), dict()
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()}
aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()}
arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()})
aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()})
# Partition the graph.
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)

#update cached graph with partitioned graph
self._cached_graph = data, out

input_names = out.list_inputs()
data_indices = []
param_indices = []

# In the default case, _cached_ops_args contains all the parameters from params (the sets are identical)
# In the case of Partition API optimized graph _cached_ops_args might contain some parameters from params,
# might contain some new parameters created during optimization and added to `arg_dict/aux_dict`,
# and might not contain some parameters that were deleted during optimization.
self._cached_op_args = []
for i, name in enumerate(input_names):
pair = None
if name in data_names:
data_indices.append(i)
pair = (True, data_names[name])
else:
param_indices.append(i)
if name in params:
param = params[name]
else:
# The param is missing from the original params dictionary, which means the param must have
# been added by the Partition API backend
if name in arg_dict or name:
param_data = arg_dict[name]
elif name in aux_dict:
param_data = aux_dict[name]
else:
raise RuntimeError('A parameter was added to the graph during optimization but it was not '
'added to the parameter dicts.\n'
'Please check the backend.')

param = Parameter(name)
param._load_init(param_data, args[0].context)
pair = (False, param)

self._cached_op_args.append(pair)

flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)


Expand Down Expand Up @@ -1321,12 +1349,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
for param in self.collect_params().values():
if param.name in arg_names:
arg_dict['arg:%s'%param.name] = param._reduce()
else:
assert param.name in aux_names
arg_dict['aux:%s'%param.name] = param._reduce()
for is_arg, param in self._cached_op_args:
if not is_arg:
name = param.name
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
else:
assert name in aux_names
arg_dict['aux:{}'.format(name)] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
Expand Down Expand Up @@ -1437,6 +1467,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
# pylint: disable= invalid-name
raise NotImplementedError

def reset_ctx(self, ctx):
"""Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
Parameters
----------
ctx : Context or list of Context, default :py:meth:`context.current_context()`.
Assign Parameter to given context. If ctx is a list of Context, a
copy will be made for each context.
"""
params = self.collect_params()
if self._cached_op:
for p in self._cached_op_args:
# resetting parameters creating by the partitioning backend
if p.name not in params:
p.reset_ctx(ctx)
for p in params.values():
p.reset_ctx(ctx)

class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
Expand Down
32 changes: 30 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,8 +1544,36 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
'Provide a dictionary to the aux argument to optimize_for')

# return modified symbol
return Symbol(out)
new_sym = Symbol(out)

arg_names = self.list_arguments()
new_arg_names = new_sym.list_arguments()
deleted_arg_names = set([item for item in arg_names
if item not in set(new_arg_names)])

if len(deleted_arg_names) > 0:
if args is not None:
for a_n in deleted_arg_names:
if a_n in args:
args.pop(a_n)
else:
warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
'Please ensure that your model weights match the newly optimized model.')

aux_names = self.list_auxiliary_states()
new_aux_names = new_sym.list_auxiliary_states()
deleted_aux_names = set([item for item in aux_names
if item not in set(new_aux_names)])
if len(deleted_aux_names) > 0:
if aux is not None:
for a_n in deleted_aux_names:
if a_n in aux:
aux.pop(a_n)
else:
warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
'Please ensure that your model weights match the newly optimized model.')

return new_sym

# pylint: disable=too-many-locals
def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
Expand Down

0 comments on commit 9c5b95a

Please sign in to comment.