Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix dumps for Constant initializer (#15150)
Browse files Browse the repository at this point in the history
* update dumps for const init

* add test

* fix for numpy input

* randomize test array shape and dim

* fix test

* replace type with isinstance
  • Loading branch information
abhinavs95 authored and szha committed Jul 19, 2019
1 parent eab6da6 commit da71324
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
6 changes: 6 additions & 0 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,12 @@ def __init__(self, value):
def _init_weight(self, _, arr):
arr[:] = self.value

def dumps(self):
val = self._kwargs['value']
if not np.isscalar(val):
self._kwargs['value'] = val.tolist() if isinstance(val, np.ndarray) else val.asnumpy().tolist()
return json.dumps([self.__class__.__name__.lower(), self._kwargs])

@register
class Uniform(Initializer):
"""Initializes weights with random values uniformly sampled from a given range.
Expand Down
19 changes: 18 additions & 1 deletion tests/python/unittest/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import json

def test_default_init():
data = mx.sym.Variable('data')
Expand Down Expand Up @@ -67,10 +68,26 @@ def test_bilinear_init():
bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]])
bili_2d = bili_1d * np.transpose(bili_1d)
assert (bili_2d == bili_weight.asnumpy()).all()


def test_const_init_dumps():
shape = tuple(np.random.randint(1, 10, size=np.random.randint(1, 5)))
# test NDArray input
init = mx.init.Constant(mx.nd.ones(shape))
val = init.dumps()
assert val == json.dumps([init.__class__.__name__.lower(), init._kwargs])
# test scalar input
init = mx.init.Constant(1)
assert init.dumps() == '["constant", {"value": 1}]'
# test numpy input
init = mx.init.Constant(np.ones(shape))
val = init.dumps()
assert val == json.dumps([init.__class__.__name__.lower(), init._kwargs])


if __name__ == '__main__':
test_variable_init()
test_default_init()
test_aux_init()
test_rsp_const_init()
test_bilinear_init()
test_const_init_dumps()

0 comments on commit da71324

Please sign in to comment.