diff --git a/src/common/utils.h b/src/common/utils.h
index 5044f3447f84..1687a0909839 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -36,15 +36,29 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs,
 namespace common {
 
 #if DMLC_USE_CXX11
+/*
+ * \brief Get input TBlobs from NDArrays, potentially performing cast_storage op and store
+ * temporary NDArrays in temps. If storage_fallback is false,
+ * MXNET_EXEC_STORAGE_FALLBACK env var determines whether storage type fallback is allowed.
+ */
 template <typename xpu>
 inline void GetInputBlobs(const std::vector<NDArray>& nds,
                           std::vector<TBlob> *blobs,
                           std::vector<NDArray> *temps,
-                          const OpContext& ctx) {
+                          const OpContext& ctx,
+                          bool storage_fallback = false) {
+  if (storage_fallback == false) {
+    storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true);
+  }
   for (auto& nd : nds) {
     if (nd.storage_type() != kDefaultStorage) {
+      if (storage_fallback == false) {
+        LOG(FATAL) << "Storage type conversion detected during execution. "
+                   << "You are probably executing an operator which "
+                   << "doesn't support NDArray inputs with non-default storage.";
+      }
       NDArray temp(nd.shape(), nd.ctx(), false);
-      op::CastStorageComputeEx({}, ctx, {nd}, {}, {temp});
+      op::CastStorageComputeImpl<xpu>(ctx.get_stream<xpu>(), nd, temp);
       temps->push_back(temp);
       blobs->push_back(temp.data());
     } else {
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 3e2caf15e19a..e8484346cc16 100755
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -327,7 +327,7 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs,
                      FCompute fcompute) {
   std::vector<TBlob> in_blobs, out_blobs;
   std::vector<NDArray> tmps;
-  common::GetInputBlobs<xpu>(inputs, &in_blobs, &tmps, ctx);
+  common::GetInputBlobs<xpu>(inputs, &in_blobs, &tmps, ctx, true);
   common::GetOutputBlobs(outputs, &out_blobs);
   fcompute(attrs, ctx, in_blobs, req, out_blobs);
 }
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 89323fe2dbfd..171b42f95160 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -6,6 +6,11 @@
 from numpy.testing import assert_allclose
 import numpy.random as rnd
 
+def assert_fcompex(f, *args, **kwargs):
+    prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1")
+    f(*args, **kwargs)
+    mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val)
+
 def check_sparse_nd_elemwise_binary(shapes, storage_types, f, g):
     # generate inputs
     nds = []
@@ -27,11 +32,14 @@
     op = mx.nd.elemwise_add
     for i in range(num_repeats):
         shape = [(rnd.randint(1, 10),rnd.randint(1, 10))] * 2
-        check_sparse_nd_elemwise_binary(shape, ['default_storage'] * 2, op, g)
-        check_sparse_nd_elemwise_binary(shape, ['default_storage', 'row_sparse'], op, g)
-        check_sparse_nd_elemwise_binary(shape, ['row_sparse', 'row_sparse'], op, g)
-
-# Test a operator which doesn't implement FComputeEx
+        assert_fcompex(check_sparse_nd_elemwise_binary,
+                       shape, ['default_storage'] * 2, op, g)
+        assert_fcompex(check_sparse_nd_elemwise_binary,
+                       shape, ['default_storage', 'row_sparse'], op, g)
+        assert_fcompex(check_sparse_nd_elemwise_binary,
+                       shape, ['row_sparse', 'row_sparse'], op, g)
+
+# test a operator which doesn't implement FComputeEx
 def test_sparse_nd_elementwise_fallback():
     num_repeats = 10
     g = lambda x,y: x + y
@@ -141,9 +149,9 @@ def check_sparse_nd_csr_slice(shape):
 
 if __name__ == '__main__':
     test_sparse_nd_zeros()
+    test_sparse_nd_elemwise_add()
     test_sparse_nd_elementwise_fallback()
     test_sparse_nd_copy()
-    test_sparse_nd_elemwise_add()
     test_sparse_nd_setitem()
     test_sparse_nd_basic()
     test_sparse_nd_slice()