From 6f0ff5658bed14aeb630156c3c2452d83d1f1569 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Wed, 23 Oct 2019 20:37:01 +0000 Subject: [PATCH] Make deduplication during saving optional --- python/mxnet/gluon/block.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index e12b0b6d2513..01d32bf0d126 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''): ret.update(child._collect_params_with_prefix(prefix + name)) return ret - def save_parameters(self, filename): + def save_parameters(self, filename, deduplicate=False): """Save parameters to file. Saved parameters can only be loaded with `load_parameters`. Note that this @@ -424,6 +424,10 @@ def save_parameters(self, filename): ---------- filename : str Path to file. + deduplicate : bool, default False + If True, save shared parameters only once. Otherwise, if a Block + contains multiple sub-blocks that share parameters, each of the + shared parameters will be separately saved for every sub-block. References ---------- @@ -432,15 +436,16 @@ def save_parameters(self, filename): """ params = self._collect_params_with_prefix() - # Shared parameters are stored only a single time as of MXNet 1.6. - # Shared parameters are registered under multiple prefixes returned by - # _collect_params_with_prefix. We select a single one and only store - # it. In load_parameters it is sufficient for a shared parameter to - # only set it for a single prefix. - reverse_params = {v: k for k, v in params.items()} - params = {v: k for k, v in reverse_params.items()} - - arg_dict = {key : val._reduce() for key, val in params.items()} + if deduplicate: + # Shared parameters are stored only a single time as of MXNet 1.6. + # Shared parameters are registered under multiple prefixes returned by + # _collect_params_with_prefix. We select a single one and only store + # it. In load_parameters it is sufficient for a shared parameter to + # only set it for a single prefix. + reverse_params = {v: k for k, v in params.items()} + params = {v: k for k, v in reverse_params.items()} + + arg_dict = {key: val._reduce() for key, val in params.items()} save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn(filename, arg_dict)