Skip to content

Commit

Permalink
Add list_ctx to ParameterDict (apache#16185)
Browse files Browse the repository at this point in the history
* Add list_ctx to ParameterDict

Signed-off-by: Serge Panev <[email protected]>

* Add assert to test reset_ctx

Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L authored and drivanov committed Sep 26, 2019
1 parent 6d6d9de commit d3ce70b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,14 @@ def reset_ctx(self, ctx):
for i in self.values():
i.reset_ctx(ctx)

def list_ctx(self):
"""Returns a list of all the contexts on which the underlying Parameters
are initialized."""
s = set()
for i in self.values():
s.update(i.list_ctx())
return list(s)

def setattr(self, name, value):
"""Set an attribute to a new value for all Parameters.
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,20 @@ def test_parameter_dict():
mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())

# test reset_ctx
params3 = gluon.ParameterDict('net_')
params3.get('w0', shape=(10, 10))
params3.get('w1', shape=(10, 10))
params3.initialize(ctx=ctx)
list_contexts = [mx.cpu(42), mx.cpu(24)]
params3.reset_ctx(list_contexts)
for p in params3.values():
assert set(p.list_ctx()) == set(list_contexts)

# and test list_ctx
assert set(params3.list_ctx()) == set(list_contexts)


# test the dtype casting functionality
params0 = gluon.ParameterDict('')
params0.get('w0', shape=(10, 10), dtype='float32')
Expand Down

0 comments on commit d3ce70b

Please sign in to comment.