Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add Large Dim Checks for linalg Operators (#18816)
Browse files Browse the repository at this point in the history
* initial

* test

* gemm and gemm2

* type fix

* syrk trmm trsm

* gelqf

* move tests from test_large_array.py to test_large_vector.py

* fix white space issue

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
Zha0q1 and Ubuntu committed Jul 31, 2020
1 parent 84c9e0d commit f4e62df
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
23 changes: 23 additions & 0 deletions src/operator/tensor/la_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ struct LaTrianParam : public dmlc::Parameter<LaTrianParam> {
}
};

// check if any dim will overflow 32-bit int
inline void check_large_dim(std::vector<dim_t> dims) {
for (dim_t dim : dims) {
CHECK_LE(dim, INT_MAX)
<< "Large matrix dimensions (>= 2^31) are not supported";
}
}

// Common function for shape inference for matrix mult and matrix mac.
inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
Expand All @@ -181,6 +189,11 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
CHECK(axis >= 0 && axis < ndim-1)
<< "Invalid row axis (" << axis_param << ")";
// Check if input matrix dims are too large
check_large_dim({(*in_attrs)[0][axis],
(*in_attrs)[0][ndim-1],
(*in_attrs)[1][axis],
(*in_attrs)[1][ndim-1]});
std::vector<int> oshape(ndim);
for ( int i = 0; i < ndim-1; ++i ) {
if (i != axis) {
Expand Down Expand Up @@ -225,6 +238,10 @@ inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
<< "Shapes of inputs 0, 1 must be the same, except on last two dimensions";
oshape[i] = (*in_attrs)[0][i];
}
// Check if the input matrix dims are too large; it suffices to check the second
// input only because the first is square whose size is bounded by memory
check_large_dim({(*in_attrs)[1][ndim-1],
(*in_attrs)[1][ndim-2]});
if ( param.rightside ) {
// We compute B * A where A is the first and B the second input.
CHECK_EQ((*in_attrs)[0][ndim-2], (*in_attrs)[1][ndim-1])
Expand Down Expand Up @@ -341,6 +358,9 @@ inline bool LaSyrkShape(const nnvm::NodeAttrs& attrs,
bool transpose = nnvm::get<LaSyrkParam>(attrs.parsed).transpose;
const int ndim = in_attr.ndim();
if ( ndim >= 2 ) {
// Check if input matrix dims are too large
check_large_dim({in_attr[ndim-1],
in_attr[ndim-2]});
// Forward shape inference.
std::vector<int> oshape(ndim);
for ( int i = 0; i < ndim-2; ++i ) {
Expand Down Expand Up @@ -371,6 +391,9 @@ inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs,
const int ndim(in_a.ndim());
CHECK_LE(in_a[ndim-2], in_a[ndim-1])
<< "Input A shape wrong: Last dimension must be >= than second to last";
// Check if the last dimension is too large; it suffices to check the last dim
// only since the second to last dim <= last dim
check_large_dim({in_a[ndim-1]});
// Q must have same shape as A
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a);
std::vector<int> oshape_l(ndim);
Expand Down
51 changes: 50 additions & 1 deletion tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed
from tests.python.unittest.common import with_seed, assertRaises
from mxnet.base import MXNetError
from nose.tools import with_setup
import unittest

# dimension constants
LARGE_X = 4300000000
MEDIUM_X = 1000000000
INT32_MAX = 2**31-1


def test_nn():
Expand Down Expand Up @@ -1064,6 +1066,53 @@ def check_minimum():
check_minimum()


# openblas and cublas are known to not work well with large
# matrix dims under current configuration. checks are added
# to exit from such use cases
def test_linalg_large_dim():
def check_gemm():
A = nd.ones(shape=(1, INT32_MAX + 1, 1))
B = nd.ones(shape=(1, INT32_MAX + 1, 1))
C = nd.ones(shape=(1, 1, 1))
assertRaises(MXNetError, nd.linalg.gemm, \
A, B, C, transpose_b=True)

def check_gemm2():
A = nd.ones(shape=(1, 1, INT32_MAX + 1))
B = nd.ones(shape=(1, 1, INT32_MAX + 1))
assertRaises(MXNetError, nd.linalg.gemm2, \
A, B, transpose_b=True)

def check_trmm():
A = nd.ones(shape=(1, 1, 1))
B = nd.ones(shape=(1, INT32_MAX + 1, 1))
assertRaises(MXNetError, nd.linalg.trmm, \
A, B, rightside=True)

def check_trsm():
A = nd.ones(shape=(1, 1, 1))
B = nd.ones(shape=(1, 1, INT32_MAX + 1))
assertRaises(MXNetError, nd.linalg.trsm, \
A, B, rightside=False)

def check_syrk():
A = nd.ones(shape=(1, INT32_MAX + 1, 1))
assertRaises(MXNetError, nd.linalg.syrk, A)
assertRaises(MXNetError, nd.linalg.syrk, A, transpose=True)

def check_gelqf():
A = nd.ones(shape=(1, 1, INT32_MAX + 1))
assertRaises(MXNetError, nd.linalg.gelqf, A)

# batch input
check_gemm()
check_gemm2()
check_trmm()
check_trsm()
check_syrk()
check_gelqf()


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit f4e62df

Please sign in to comment.