Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Make deduplication during saving optional
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Oct 24, 2019
1 parent c9b6a75 commit 6f0ff56
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''):
ret.update(child._collect_params_with_prefix(prefix + name))
return ret

def save_parameters(self, filename):
def save_parameters(self, filename, deduplicate=False):
"""Save parameters to file.
Saved parameters can only be loaded with `load_parameters`. Note that this
Expand All @@ -424,6 +424,10 @@ def save_parameters(self, filename):
----------
filename : str
Path to file.
deduplicate : bool, default False
If True, save shared parameters only once. Otherwise, if a Block
contains multiple sub-blocks that share parameters, each of the
shared parameters will be separately saved for every sub-block.
References
----------
Expand All @@ -432,15 +436,16 @@ def save_parameters(self, filename):
"""
params = self._collect_params_with_prefix()

# Shared parameters are stored only a single time as of MXNet 1.6.
# Shared parameters are registered under multiple prefixes returned by
# _collect_params_with_prefix. We select a single one and only store
# it. In load_parameters it is sufficient for a shared parameter to
# only set it for a single prefix.
reverse_params = {v: k for k, v in params.items()}
params = {v: k for k, v in reverse_params.items()}

arg_dict = {key : val._reduce() for key, val in params.items()}
if deduplicate:
# Shared parameters are stored only a single time as of MXNet 1.6.
# Shared parameters are registered under multiple prefixes returned by
# _collect_params_with_prefix. We select a single one and only store
# it. In load_parameters it is sufficient for a shared parameter to
# only set it for a single prefix.
reverse_params = {v: k for k, v in params.items()}
params = {v: k for k, v in reverse_params.items()}

arg_dict = {key: val._reduce() for key, val in params.items()}
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn(filename, arg_dict)

Expand Down

0 comments on commit 6f0ff56

Please sign in to comment.