Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add list_ctx to ParameterDict
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Sep 25, 2019
1 parent ab2214b commit 1e442ce
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,14 @@ def reset_ctx(self, ctx):
for i in self.values():
i.reset_ctx(ctx)

def list_ctx(self):
"""Returns a list of all the contexts on which the underlying Parameters
are initialized."""
s = set()
for i in self.values():
s.update(i.list_ctx())
return list(s)

def setattr(self, name, value):
"""Set an attribute to a new value for all Parameters.
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ def test_parameter_dict():
mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())

# test reset_ctx
params3 = gluon.ParameterDict('net_')
params3.get('w0', shape=(10, 10))
params3.get('w1', shape=(10, 10))
params3.initialize(ctx=ctx)
list_contexts = [mx.cpu(42), mx.cpu(24)]
params3.reset_ctx(list_contexts)
# and test list_ctx
assert set(params3.list_ctx()) == set(list_contexts)


# test the dtype casting functionality
params0 = gluon.ParameterDict('')
params0.get('w0', shape=(10, 10), dtype='float32')
Expand Down

0 comments on commit 1e442ce

Please sign in to comment.