Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix dumps for Constant initializer #15150

Merged
merged 6 commits into from
Jul 19, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,12 @@ def __init__(self, value):
def _init_weight(self, _, arr):
arr[:] = self.value

def dumps(self):
val = self._kwargs['value']
if not np.isscalar(val):
self._kwargs['value'] = val.tolist() if type(val).__module__ == 'numpy' else val.asnumpy().tolist()
szha marked this conversation as resolved.
Show resolved Hide resolved
return json.dumps([self.__class__.__name__.lower(), self._kwargs])

@register
class Uniform(Initializer):
"""Initializes weights with random values uniformly sampled from a given range.
Expand Down
19 changes: 18 additions & 1 deletion tests/python/unittest/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import json

def test_default_init():
data = mx.sym.Variable('data')
Expand Down Expand Up @@ -67,10 +68,26 @@ def test_bilinear_init():
bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]])
bili_2d = bili_1d * np.transpose(bili_1d)
assert (bili_2d == bili_weight.asnumpy()).all()


def test_const_init_dumps():
shape = tuple(np.random.randint(1, 10, size=np.random.randint(1, 5)))
# test NDArray input
init = mx.init.Constant(mx.nd.ones(shape))
val = init.dumps()
assert val == json.dumps([init.__class__.__name__.lower(), init._kwargs])
# test scalar input
init = mx.init.Constant(1)
assert init.dumps() == '["constant", {"value": 1}]'
# test numpy input
init = mx.init.Constant(np.ones(shape))
val = init.dumps()
assert val == json.dumps([init.__class__.__name__.lower(), init._kwargs])


if __name__ == '__main__':
test_variable_init()
test_default_init()
test_aux_init()
test_rsp_const_init()
test_bilinear_init()
test_const_init_dumps()