diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index aca7c58707e2..277bfd0f4fa5 100755 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -464,6 +464,12 @@ def __init__(self, value): def _init_weight(self, _, arr): arr[:] = self.value + def dumps(self): + val = self._kwargs['value'] + if not np.isscalar(val): + self._kwargs['value'] = val.tolist() if isinstance(val, np.ndarray) else val.asnumpy().tolist() + return json.dumps([self.__class__.__name__.lower(), self._kwargs]) + @register class Uniform(Initializer): """Initializes weights with random values uniformly sampled from a given range. diff --git a/tests/python/unittest/test_init.py b/tests/python/unittest/test_init.py index c8bf01f48ca3..6d8830c1d089 100644 --- a/tests/python/unittest/test_init.py +++ b/tests/python/unittest/test_init.py @@ -17,6 +17,7 @@ import mxnet as mx import numpy as np +import json def test_default_init(): data = mx.sym.Variable('data') @@ -67,10 +68,26 @@ def test_bilinear_init(): bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]]) bili_2d = bili_1d * np.transpose(bili_1d) assert (bili_2d == bili_weight.asnumpy()).all() - + +def test_const_init_dumps(): + shape = tuple(np.random.randint(1, 10, size=np.random.randint(1, 5))) + # test NDArray input + init = mx.init.Constant(mx.nd.ones(shape)) + val = init.dumps() + assert val == json.dumps([init.__class__.__name__.lower(), init._kwargs]) + # test scalar input + init = mx.init.Constant(1) + assert init.dumps() == '["constant", {"value": 1}]' + # test numpy input + init = mx.init.Constant(np.ones(shape)) + val = init.dumps() + assert val == json.dumps([init.__class__.__name__.lower(), init._kwargs]) + + if __name__ == '__main__': test_variable_init() test_default_init() test_aux_init() test_rsp_const_init() test_bilinear_init() + test_const_init_dumps()