diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index c3338f4c766b..d3eeba1b504e 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -952,8 +952,9 @@ def update(self, index, weight, grad, state): grad = grad * self.rescale_grad if self.clip_gradient is not None: grad = clip(grad, -self.clip_gradient, self.clip_gradient) - weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr), - weight.shape, weight.context) + weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr), shape=weight.shape, + dtype=weight.dtype, ctx=weight.context) + @register # pylint: disable=invalid-name