From 0fd2f8dc160e094a68aa04e759599086cb6f7401 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Fri, 14 Jun 2019 14:26:36 -0700 Subject: [PATCH] fix for ch11 (#15244) --- python/mxnet/gluon/parameter.py | 2 +- python/mxnet/numpy_extension/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 86ee9ad4a55b..0797b4cf4d47 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -369,7 +369,7 @@ def _reduce(self): ctx = context.cpu() if self._stype == 'default': block = self.list_data() - data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) + data = ndarray.add_n(*(w.copyto(ctx).as_nd_ndarray() for w in block)) / len(block) else: # fetch all rows for 'row_sparse' param all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx) diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index e2ccaa1dfc14..0e2d005df394 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -28,5 +28,6 @@ from ..util import use_np_shape, np_shape, is_np_shape from ..util import use_np_array, np_array, is_np_array from ..util import set_np, use_np, reset_np +from ..ndarray import waitall __all__ = []