From 9ca57b4a966da601d7d61d7c95ddb55147fc35b1 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Wed, 21 Oct 2020 04:38:26 +0000 Subject: [PATCH] backport #19393 to v1.x --- python/mxnet/gluon/block.py | 2 +- python/mxnet/gluon/parameter.py | 9 ++++++--- src/c_api/c_api.cc | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index a89d4bc13f13..765790c1a0da 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1072,7 +1072,7 @@ def _build_cache(self, *args): 'added to the parameter dicts.\n' 'Please check the backend.') - param = Parameter(name) + param = Parameter(name, dtype=param_data.dtype) param._var_name = name serialization_name = name # HybridBlock.export param._load_init(param_data, args[0].context) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 6f66947d2c10..26d061a703fe 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -400,10 +400,13 @@ def _reduce(self): ctx = context.cpu() if self._stype == 'default': block = self.list_data() - if is_np_array(): - data = sum([w.copyto(ctx) for w in block]) / len(block) + if len(block) > 1: + if is_np_array(): + data = sum([w.copyto(ctx) for w in block]) / len(block) + else: + data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) else: - data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) + data = self.data().copyto(ctx) else: # fetch all rows for 'row_sparse' param all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 36678636391c..3c793dd01c7f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1417,7 +1417,7 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // this temp workspace holds memory allocated by custom library via OpResource auto ndarray_alloc = [&](const mxnet::TShape &shape, Context ctx, int dtype, std::string name, bool isArg) { - NDArray* arr = new NDArray(shape, ctx, dtype); + NDArray* arr = new NDArray(shape, ctx, false, dtype); if (isArg) { new_args.push_back(arr); new_arg_names.push_back(name);