diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index eff7dd754572..629ff22ec4e0 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -24,7 +24,7 @@ import copy import warnings import re -from collections import OrderedDict +from collections import OrderedDict, defaultdict from ..base import mx_real_t, MXNetError from .. import symbol, ndarray, initializer @@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''): ret.update(child._collect_params_with_prefix(prefix + name)) return ret - def save_parameters(self, filename): + def save_parameters(self, filename, deduplicate=False): """Save parameters to file. Saved parameters can only be loaded with `load_parameters`. Note that this @@ -424,6 +424,10 @@ def save_parameters(self, filename): ---------- filename : str Path to file. + deduplicate : bool, default False + If True, save shared parameters only once. Otherwise, if a Block + contains multiple sub-blocks that share parameters, each of the + shared parameters will be separately saved for every sub-block. References ---------- @@ -431,7 +435,17 @@ def save_parameters(self, filename): `_ """ params = self._collect_params_with_prefix() - arg_dict = {key : val._reduce() for key, val in params.items()} + + if deduplicate: + # Shared parameters are stored only a single time as of MXNet 1.6. + # Shared parameters are registered under multiple prefixes returned by + # _collect_params_with_prefix. We select a single one and only store + # it. In load_parameters it is sufficient for a shared parameter to + # only set it for a single prefix. + reverse_params = {v: k for k, v in params.items()} + params = {v: k for k, v in reverse_params.items()} + + arg_dict = {key: val._reduce() for key, val in params.items()} save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn(filename, arg_dict) @@ -510,15 +524,24 @@ def load_parameters(self, filename, ctx=None, allow_missing=False, if not any('.' 
in i for i in loaded.keys()): # legacy loading - del loaded + loaded = None # This should be changed to `del loaded` when dropping Python 2 self.collect_params().load( filename, ctx, allow_missing, ignore_extra, self.prefix, cast_dtype=cast_dtype, dtype_source=dtype_source) return if not allow_missing: - for name in params.keys(): - assert name in loaded, \ + # A file saved with deduplicate=True stores each shared parameter only + # once, under a single one of its prefixes. We therefore group all + # prefixes (from _collect_params_with_prefix) by the parameter they + # refer to, and treat a parameter as present when at least one of its + # prefixes appears in the loaded file. + params_inv = defaultdict(list) + for k, v in params.items(): + params_inv[v].append(k) + + for name, param in params.items(): + assert any(p in loaded for p in params_inv[param]), \ "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \ "Set allow_missing=True to ignore missing parameters."%( name, filename, _brief_print_list(loaded.keys())) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index f1d0cc7ac274..f1413e2b99c2 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1511,6 +1511,46 @@ def forward(self, x): net2 = Network() net2.load_parameters('tmp.params') +@with_seed() +def test_save_load_deduplicate_with_shared_params(): + class B(mx.gluon.Block): + def __init__(self, params=None): + super(B, self).__init__(params=params) + + with self.name_scope(): + self.weight = self.params.get('weight', shape=(10, 10)) + + class C(mx.gluon.Block): + def __init__(self, b1, b2): + super(C, self).__init__() + self.b1 = b1 + self.b2 = b2 + + b1 = B() + b2 = B(b1.collect_params()) + c = C(b1, b2) + c.initialize() + c.save_parameters('tmp.params', deduplicate=True) + + params = mx.nd.load('tmp.params') + assert len(params) == 1 # Only a single copy of the shared parameter is saved + + b1 = B() + b2 = B(b1.collect_params()) + c = 
C(b1, b2) + c.load_parameters('tmp.params') + + # Test default behavior + c.save_parameters('tmp2.params', deduplicate=False) + + params = mx.nd.load('tmp2.params') + assert len(params) == 2 # Without deduplication, one copy is saved per sharing sub-block + + b1 = B() + b2 = B(b1.collect_params()) + c = C(b1, b2) + c.load_parameters('tmp2.params') + @with_seed() def test_symbol_block_save_load(): class Net(gluon.HybridBlock):