From 0e2d338358badd71f85f72a9ae06c7df5e10e776 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Wed, 21 Oct 2020 04:50:53 +0000 Subject: [PATCH] backport #19393 to v1.x --- python/mxnet/gluon/block.py | 2 +- python/mxnet/gluon/parameter.py | 9 ++++++--- src/c_api/c_api.cc | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 8282c93a6f6d..850ce1fffb6d 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -999,7 +999,7 @@ def _build_cache(self, *args): 'added to the parameter dicts.\n' 'Please check the backend.') - param = Parameter(name) + param = Parameter(name, dtype=param_data.dtype) param._load_init(param_data, args[0].context) pair = (False, param) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 55b0f4a963a1..89456cb07f7b 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -389,10 +389,13 @@ def _reduce(self): ctx = context.cpu() if self._stype == 'default': block = self.list_data() - if is_np_array(): - data = sum([w.copyto(ctx) for w in block]) / len(block) + if len(block) > 1: + if is_np_array(): + data = sum([w.copyto(ctx) for w in block]) / len(block) + else: + data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) else: - data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) + data = self.data().copyto(ctx) else: # fetch all rows for 'row_sparse' param all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5e91bccde99e..26c8ab201389 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1416,7 +1416,7 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // this temp workspace holds memory allocated by custom library via OpResource auto ndarray_alloc = [&](const mxnet::TShape &shape, Context ctx, int dtype, std::string name, bool isArg) { - NDArray* arr = new NDArray(shape, ctx, dtype); + NDArray* arr = new NDArray(shape, ctx, false, dtype); if (isArg) { new_args.push_back(arr); new_arg_names.push_back(name);