Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Oct 22, 2019
1 parent 04fd54f commit 33e00ed
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,35 @@ def forward(self, x):
net2 = Network()
net2.load_parameters('tmp.params')

@with_seed()
def test_save_load_with_shared_params():
class B(mx.gluon.Block):
def __init__(self, params=None):
super().__init__(params=params)

with self.name_scope():
self.weight = self.params.get('weight', shape=(10, 10))

class C(mx.gluon.Block):
def __init__(self, b1, b2):
super().__init__()
self.b1 = b1
self.b2 = b2

b1 = B()
b2 = B(b1.collect_params())
c = C(b1, b2)
c.initialize()
c.save_parameters('tmp.params')

params = mx.nd.load('tmp.params')
assert len(params) == 1 # Only a single copy of the shared parameter is saved

b1 = B()
b2 = B(b1.collect_params())
c = C(b1, b2)
c.load_parameters('tmp.params')

@with_seed()
def test_symbol_block_save_load():
class Net(gluon.HybridBlock):
Expand Down

0 comments on commit 33e00ed

Please sign in to comment.