From b9cd7ef9aad31699633e48507a150908f895db04 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 25 Dec 2019 08:44:16 +0000 Subject: [PATCH] fix norm sparse fallback --- src/operator/tensor/broadcast_reduce_norm_value.cc | 2 +- src/operator/tensor/broadcast_reduce_norm_value.cu | 2 +- src/operator/tensor/broadcast_reduce_op.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_norm_value.cc b/src/operator/tensor/broadcast_reduce_norm_value.cc index 4cd92d44997e..9acc157f8eca 100644 --- a/src/operator/tensor/broadcast_reduce_norm_value.cc +++ b/src/operator/tensor/broadcast_reduce_norm_value.cc @@ -40,7 +40,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs, const NormParam& param = nnvm::get(attrs.parsed); mshadow::Stream* s = ctx.get_stream(); const NDArrayStorageType istype = inputs[0].storage_type(); - const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(); + const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(0, -1); if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 && param.ord == 2) { // l2 norm on the entire array diff --git a/src/operator/tensor/broadcast_reduce_norm_value.cu b/src/operator/tensor/broadcast_reduce_norm_value.cu index 188c93e61221..735c3d7faec9 100644 --- a/src/operator/tensor/broadcast_reduce_norm_value.cu +++ b/src/operator/tensor/broadcast_reduce_norm_value.cu @@ -39,7 +39,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs, const NormParam& param = nnvm::get(attrs.parsed); mshadow::Stream* s = ctx.get_stream(); const NDArrayStorageType istype = inputs[0].storage_type(); - const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(); + const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(0, -1); if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 && param.ord == 2) { // l2 norm on the entire array diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 27e22491ca35..799f86544160 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1152,7 +1152,7 @@ inline bool LpNormStorageType(const nnvm::NodeAttrs& attrs, DispatchMode::kFCompute); } if (param.ord == 2) { - const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(); + const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(0, -1); if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) && axis.ndim() == 0 && param.ord == 2) { // l2 norm: rsp/csr, axis = () -> dns