-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Partition API adding and deleting new params to Block and Symbol #18405
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1040,41 +1040,62 @@ def _build_cache(self, *args): | |
warnings.warn("Parameter %s is not used by any computation. " | ||
"Is this intended?"%unused, stacklevel=4) | ||
|
||
data_indices = [] | ||
param_indices = [] | ||
self._cached_op_args = [] | ||
for i, name in enumerate(input_names): | ||
if name in data_names: | ||
data_indices.append(i) | ||
self._cached_op_args.append((True, data_names[name])) | ||
else: | ||
param_indices.append(i) | ||
self._cached_op_args.append((False, params[name])) | ||
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ | ||
self._flags | ||
|
||
args, _ = _flatten(args, "input") | ||
try: | ||
for is_arg, i in self._cached_op_args: | ||
if not is_arg: | ||
i.data() | ||
for name in input_names: | ||
if name in params: | ||
params[name].data() | ||
except DeferredInitializationError: | ||
self._deferred_infer_shape(*args) | ||
for is_arg, i in self._cached_op_args: | ||
if not is_arg: | ||
i._finish_deferred_init() | ||
for name in input_names: | ||
if name in params: | ||
params[name]._finish_deferred_init() | ||
|
||
arg_dict, aux_dict = dict(), dict() | ||
if self._backend: | ||
ctx = args[0].context | ||
# get list of params in the order of out.list_arguments | ||
arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data() | ||
for name in out.list_arguments()} | ||
aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data() | ||
for name in out.list_auxiliary_states()} | ||
arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data() | ||
for name in out.list_arguments()}) | ||
aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data() | ||
for name in out.list_auxiliary_states()}) | ||
# Partition the graph. | ||
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts) | ||
|
||
#update cached graph with partitioned graph | ||
self._cached_graph = data, out | ||
|
||
input_names = out.list_inputs() | ||
data_indices = [] | ||
param_indices = [] | ||
self._cached_op_args = [] | ||
for i, name in enumerate(input_names): | ||
pair = None | ||
if name in data_names: | ||
data_indices.append(i) | ||
pair = (True, data_names[name]) | ||
else: | ||
param_indices.append(i) | ||
if name in params: | ||
param = params[name] | ||
else: | ||
assert self._backend, "Parameter " + name + " is missing from block params" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this the case that should never happen? When a param name is not in `params` […] |
||
if name in arg_dict or name: | ||
param_data = arg_dict[name] | ||
elif name in aux_dict: | ||
param_data = aux_dict[name] | ||
else: | ||
raise RuntimeError('Expected inputs missing from arg and aux after partioning. ' | ||
Kh4L marked this conversation as resolved.
Show resolved
Hide resolved
|
||
'Please check the backend.') | ||
|
||
param = Parameter(name) | ||
param._load_init(param_data, args[0].context) | ||
pair = (False, param) | ||
|
||
self._cached_op_args.append(pair) | ||
|
||
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ | ||
self._flags | ||
self._cached_op = ndarray.CachedOp(out, flags) | ||
|
||
|
||
|
@@ -1321,12 +1342,14 @@ def export(self, path, epoch=0, remove_amp_cast=True): | |
arg_names = set(sym.list_arguments()) | ||
aux_names = set(sym.list_auxiliary_states()) | ||
arg_dict = {} | ||
for param in self.collect_params().values(): | ||
if param.name in arg_names: | ||
arg_dict['arg:%s'%param.name] = param._reduce() | ||
else: | ||
assert param.name in aux_names | ||
arg_dict['aux:%s'%param.name] = param._reduce() | ||
for is_arg, param in self._cached_op_args: | ||
if not is_arg: | ||
name = param.name | ||
if name in arg_names: | ||
arg_dict['arg:{}'.format(name)] = param._reduce() | ||
else: | ||
assert name in aux_names | ||
arg_dict['aux:{}'.format(name)] = param._reduce() | ||
save_fn = _mx_npx.save if is_np_array() else ndarray.save | ||
params_filename = '%s-%04d.params'%(path, epoch) | ||
save_fn(params_filename, arg_dict) | ||
|
@@ -1437,6 +1460,23 @@ def hybrid_forward(self, F, x, *args, **kwargs): | |
# pylint: disable= invalid-name | ||
raise NotImplementedError | ||
|
||
def reset_ctx(self, ctx): | ||
"""Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args. | ||
|
||
Parameters | ||
---------- | ||
ctx : Context or list of Context, default :py:meth:`context.current_context()`. | ||
Assign Parameter to given context. If ctx is a list of Context, a | ||
copy will be made for each context. | ||
""" | ||
params = self.collect_params() | ||
if self._cached_op: | ||
for p in self._cached_op_args: | ||
# resetting parameters created by the partitioning backend | ||
if p.name not in params: | ||
p.reset_ctx(ctx) | ||
for p in params.values(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you explain why we need to loop over […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Although I guess if we delete a param, then it will still be in […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I will add some comments to clarify. |
||
p.reset_ctx(ctx) | ||
|
||
class SymbolBlock(HybridBlock): | ||
"""Construct block from symbol. This is useful for using pre-trained models | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1544,8 +1544,35 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): | |
raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' + | ||
'Provide a dictionary to the aux argument to optimize_for') | ||
|
||
# return modified symbol | ||
return Symbol(out) | ||
new_sym = Symbol(out) | ||
|
||
arg_names = self.list_arguments() | ||
new_arg_names = new_sym.list_arguments() | ||
deleted_arg_names = set([item for item in arg_names | ||
if item not in set(new_arg_names)]) | ||
|
||
if len(deleted_arg_names) > 0: | ||
if args is not None: | ||
for a_n in deleted_arg_names: | ||
if a_n in args: | ||
args.pop(a_n) | ||
else: | ||
warnings.warn('optimize_for deleted some argument. \n' + | ||
'Provide a dictionary to the arg argument to optimize_for') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we clarify this error message to something like: […]
Should we print the names of the deleted_arg_names in the message too (or is that overkill)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion! I think that this new message will be sufficient; printing the list of deleted arg names could potentially be very long and flood stderr. |
||
aux_names = self.list_auxiliary_states() | ||
new_aux_names = new_sym.list_auxiliary_states() | ||
deleted_aux_names = set([item for item in aux_names | ||
if item not in set(new_aux_names)]) | ||
if len(deleted_aux_names) > 0: | ||
if aux is not None: | ||
for a_n in deleted_aux_names: | ||
if a_n in aux: | ||
aux.pop(a_n) | ||
else: | ||
warnings.warn('optimize_for deleted some aux argument. \n' + | ||
'Provide a dictionary to the aux argument to optimize_for') | ||
Kh4L marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return new_sym | ||
|
||
# pylint: disable=too-many-locals | ||
def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you initialize with `dict()` and then call update on an empty dictionary instead of just assigning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`arg_dict` and `aux_dict` could otherwise be undefined below the `assert` in line 1083. This could be a SyntaxError or linter error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right; also, Python's scoping rules are sometimes a bit unsettling, so I thought that this would make it clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense, thanks!