From 08895b71111e11eaf60280fa59aca93fb4f62193 Mon Sep 17 00:00:00 2001 From: Anirudh Date: Tue, 7 May 2019 05:10:33 -0700 Subject: [PATCH] Fix the return type of sparse.clip operator (#14856) * stype fix * ut * retrigger ci * Retrigger ci --- src/operator/tensor/matrix_op.cc | 2 +- tests/python/unittest/test_sparse_ndarray.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index b80c9a54510f..9e6bead7229c 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -769,7 +769,7 @@ parameter values: if (!dispatched && param.a_min <= 0.0 && param.a_max >= 0.0) { const int this_stype = (*in_attrs)[0]; if (this_stype != kUndefinedStorage) { - dispatched = storage_type_assign(&(*out_attrs)[0], kRowSparseStorage, + dispatched = storage_type_assign(&(*out_attrs)[0], mxnet::NDArrayStorageType(this_stype), dispatch_mode, DispatchMode::kFComputeEx); } } diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 7600ea944e83..3b4c684e8696 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -915,6 +915,7 @@ def check_fluent_regular(stype, func, kwargs, shape=(5, 17), equal_nan=False): check_fluent_regular('csr', 'slice', {'begin': (2, 5), 'end': (4, 7)}, shape=(5, 17)) check_fluent_regular('row_sparse', 'clip', {'a_min': -0.25, 'a_max': 0.75}) + check_fluent_regular('csr', 'clip', {'a_min': -0.25, 'a_max': 0.75}) for func in ['sum', 'mean', 'norm']: check_fluent_regular('csr', func, {'axis': 0})