Skip to content

Commit

Permalink
[Gluon] Don't serialize shared parameters twice (apache#16582)
Browse files Browse the repository at this point in the history
Add deduplicate argument (default of False) to save_parameters.
  • Loading branch information
leezu authored and sxjscience committed Oct 26, 2019
1 parent 4dbf421 commit 0fa8f17
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 6 deletions.
35 changes: 29 additions & 6 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import copy
import warnings
import re
from collections import OrderedDict
from collections import OrderedDict, defaultdict

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer
Expand Down Expand Up @@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''):
ret.update(child._collect_params_with_prefix(prefix + name))
return ret

def save_parameters(self, filename):
def save_parameters(self, filename, deduplicate=False):
"""Save parameters to file.
Saved parameters can only be loaded with `load_parameters`. Note that this
Expand All @@ -424,14 +424,28 @@ def save_parameters(self, filename):
----------
filename : str
Path to file.
deduplicate : bool, default False
If True, save shared parameters only once. Otherwise, if a Block
contains multiple sub-blocks that share parameters, each of the
shared parameters will be separately saved for every sub-block.
References
----------
`Saving and Loading Gluon Models \
<https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
"""
params = self._collect_params_with_prefix()
arg_dict = {key : val._reduce() for key, val in params.items()}

if deduplicate:
# Shared parameters are stored only a single time as of MXNet 1.6.
# A shared parameter is registered under multiple prefixes returned by
# _collect_params_with_prefix. We select a single prefix and store the
# parameter only under that prefix. In load_parameters it is sufficient
# to set a shared parameter through any one of its prefixes.
reverse_params = {v: k for k, v in params.items()}
params = {v: k for k, v in reverse_params.items()}

arg_dict = {key: val._reduce() for key, val in params.items()}
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn(filename, arg_dict)

Expand Down Expand Up @@ -510,15 +524,24 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,

if not any('.' in i for i in loaded.keys()):
# legacy loading
del loaded
loaded = None # This should be changed to `del loaded` when dropping Python 2
self.collect_params().load(
filename, ctx, allow_missing, ignore_extra, self.prefix,
cast_dtype=cast_dtype, dtype_source=dtype_source)
return

if not allow_missing:
for name in params.keys():
assert name in loaded, \
# Shared parameters may be stored only a single time as of MXNet 1.6.
# We thus retrieve all prefixes (through _collect_params_with_prefix)
# that a shared parameter is used with. A parameter only counts as
# missing if none of the prefixes it is registered under is present
# in the loaded file.
params_inv = defaultdict(list)
for k, v in params.items():
params_inv[v].append(k)

for name, param in params.items():
assert any(p in loaded for p in params_inv[param]), \
"Parameter '%s' is missing in file '%s', which contains parameters: %s. " \
"Set allow_missing=True to ignore missing parameters."%(
name, filename, _brief_print_list(loaded.keys()))
Expand Down
40 changes: 40 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,46 @@ def forward(self, x):
net2 = Network()
net2.load_parameters('tmp.params')

@with_seed()
def test_save_load_deduplicate_with_shared_params():
# Round-trips save_parameters/load_parameters for a Block whose two children
# are built from the same parameter dict, with deduplicate=True and
# deduplicate=False.
# B declares a single 'weight' parameter of shape (10, 10); it can be
# constructed from an existing parameter dict passed as `params`.
class B(mx.gluon.Block):
def __init__(self, params=None):
super(B, self).__init__(params=params)

with self.name_scope():
self.weight = self.params.get('weight', shape=(10, 10))

# C simply holds two B children, so a parameter passed to both is reachable
# through both the `b1` and the `b2` attribute.
class C(mx.gluon.Block):
def __init__(self, b1, b2):
super(C, self).__init__()
self.b1 = b1
self.b2 = b2

# b2 is built from b1's parameters.
b1 = B()
b2 = B(b1.collect_params())
c = C(b1, b2)
c.initialize()
c.save_parameters('tmp.params', deduplicate=True)

params = mx.nd.load('tmp.params')
assert len(params) == 1 # Only a single copy of the shared parameter is saved

# A freshly built network with the same sharing must load the deduplicated file.
b1 = B()
b2 = B(b1.collect_params())
c = C(b1, b2)
c.load_parameters('tmp.params')

# Test default behavior
c.save_parameters('tmp2.params', deduplicate=False)

params = mx.nd.load('tmp2.params')
assert len(params) == 2 # Without deduplication the shared parameter is saved once per sub-block

# The non-deduplicated file must still load into a freshly built network.
b1 = B()
b2 = B(b1.collect_params())
c = C(b1, b2)
c.load_parameters('tmp2.params')

@with_seed()
def test_symbol_block_save_load():
class Net(gluon.HybridBlock):
Expand Down

0 comments on commit 0fa8f17

Please sign in to comment.