diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index cd097781243b..147b173f2f66 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -157,6 +157,14 @@ struct LaTrianParam : public dmlc::Parameter<LaTrianParam> {
   }
 };
 
+// check if any dim will overflow 32-bit int
+inline void check_large_dim(std::vector<dim_t> dims) {
+  for (dim_t dim : dims) {
+    CHECK_LE(dim, INT_MAX)
+      << "Large matrix dimensions (>= 2^31) are not supported";
+  }
+}
+
 // Common function for shape inference for matrix mult and matrix mac.
 inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
                                    mxnet::ShapeVector* in_attrs,
@@ -181,6 +189,11 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
   const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
   CHECK(axis >= 0 && axis < ndim-1)
     << "Invalid row axis (" << axis_param << ")";
+  // Check if input matrix dims are too large
+  check_large_dim({(*in_attrs)[0][axis],
+                   (*in_attrs)[0][ndim-1],
+                   (*in_attrs)[1][axis],
+                   (*in_attrs)[1][ndim-1]});
   std::vector<int> oshape(ndim);
   for ( int i = 0; i < ndim-1; ++i ) {
     if (i != axis) {
@@ -225,6 +238,10 @@ inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
       << "Shapes of inputs 0, 1 must be the same, except on last two dimensions";
     oshape[i] = (*in_attrs)[0][i];
   }
+  // Check if the input matrix dims are too large; it suffices to check the second
+  // input only because the first is square whose size is bounded by memory
+  check_large_dim({(*in_attrs)[1][ndim-1],
+                   (*in_attrs)[1][ndim-2]});
   if ( param.rightside ) {
     // We compute B * A where A is the first and B the second input.
     CHECK_EQ((*in_attrs)[0][ndim-2], (*in_attrs)[1][ndim-1])
@@ -341,6 +358,9 @@ inline bool LaSyrkShape(const nnvm::NodeAttrs& attrs,
   bool transpose = nnvm::get<LaSyrkParam>(attrs.parsed).transpose;
   const int ndim = in_attr.ndim();
   if ( ndim >= 2 ) {
+    // Check if input matrix dims are too large
+    check_large_dim({in_attr[ndim-1],
+                     in_attr[ndim-2]});
     // Forward shape inference.
     std::vector<int> oshape(ndim);
     for ( int i = 0; i < ndim-2; ++i ) {
@@ -371,6 +391,9 @@ inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs,
     const int ndim(in_a.ndim());
     CHECK_LE(in_a[ndim-2], in_a[ndim-1])
       << "Input A shape wrong: Last dimension must be >= than second to last";
+    // Check if the last dimension is too large; it suffices to check the last dim
+    // only since the second to last dim <= last dim
+    check_large_dim({in_a[ndim-1]});
     // Q must have same shape as A
     SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a);
     std::vector<int> oshape_l(ndim);
diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py
index bbad75627769..d4365391cf4e 100644
--- a/tests/nightly/test_large_vector.py
+++ b/tests/nightly/test_large_vector.py
@@ -27,13 +27,15 @@
 from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector
 from mxnet import gluon, nd
-from tests.python.unittest.common import with_seed
+from tests.python.unittest.common import with_seed, assertRaises
+from mxnet.base import MXNetError
 from nose.tools import with_setup
 import unittest
 
 # dimension constants
 LARGE_X = 4300000000
 MEDIUM_X = 1000000000
+INT32_MAX = 2**31-1
 
 
 def test_nn():
@@ -1064,6 +1066,53 @@ def check_minimum():
     check_minimum()
 
 
+# openblas and cublas are known to not work well with large
+# matrix dims under current configuration. checks are added
+# to exit from such use cases
+def test_linalg_large_dim():
+    def check_gemm():
+        A = nd.ones(shape=(1, INT32_MAX + 1, 1))
+        B = nd.ones(shape=(1, INT32_MAX + 1, 1))
+        C = nd.ones(shape=(1, 1, 1))
+        assertRaises(MXNetError, nd.linalg.gemm, \
+                     A, B, C, transpose_b=True)
+
+    def check_gemm2():
+        A = nd.ones(shape=(1, 1, INT32_MAX + 1))
+        B = nd.ones(shape=(1, 1, INT32_MAX + 1))
+        assertRaises(MXNetError, nd.linalg.gemm2, \
+                     A, B, transpose_b=True)
+
+    def check_trmm():
+        A = nd.ones(shape=(1, 1, 1))
+        B = nd.ones(shape=(1, INT32_MAX + 1, 1))
+        assertRaises(MXNetError, nd.linalg.trmm, \
+                     A, B, rightside=True)
+
+    def check_trsm():
+        A = nd.ones(shape=(1, 1, 1))
+        B = nd.ones(shape=(1, 1, INT32_MAX + 1))
+        assertRaises(MXNetError, nd.linalg.trsm, \
+                     A, B, rightside=False)
+
+    def check_syrk():
+        A = nd.ones(shape=(1, INT32_MAX + 1, 1))
+        assertRaises(MXNetError, nd.linalg.syrk, A)
+        assertRaises(MXNetError, nd.linalg.syrk, A, transpose=True)
+
+    def check_gelqf():
+        A = nd.ones(shape=(1, 1, INT32_MAX + 1))
+        assertRaises(MXNetError, nd.linalg.gelqf, A)
+
+    # batch input
+    check_gemm()
+    check_gemm2()
+    check_trmm()
+    check_trsm()
+    check_syrk()
+    check_gelqf()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()