Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add Large Dim Checks for linalg Operators #18816

Merged
merged 12 commits into from
Jul 31, 2020
23 changes: 23 additions & 0 deletions src/operator/tensor/la_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ struct LaTrianParam : public dmlc::Parameter<LaTrianParam> {
}
};

// check if any dim will overflow 32-bit int
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
inline void check_large_dim(std::vector<dim_t> dims) {
for (dim_t dim : dims) {
CHECK_LE(dim, INT_MAX)
<< "Large matrix dimensions (>= 2^31) are not supported";
}
}

// Common function for shape inference for matrix mult and matrix mac.
inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
Expand All @@ -181,6 +189,21 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
CHECK(axis >= 0 && axis < ndim-1)
<< "Invalid row axis (" << axis_param << ")";
// check if any dim is too large
check_large_dim({(*in_attrs)[0][axis],
(*in_attrs)[0][ndim-1],
(*in_attrs)[1][axis],
(*in_attrs)[1][ndim-1]});
/*
CHECK_LE((*in_attrs)[0][axis], INT_MAX)
<< "Large matrix dimensions (>= 2^31) are not supported";
CHECK_LE((*in_attrs)[0][ndim-1], INT_MAX)
<< "Large matrix dimensions (>= 2^31) are not supported";;
CHECK_LE((*in_attrs)[1][axis], INT_MAX)
<< "Large matrix dimensions (>= 2^31) are not supported";;
CHECK_LE((*in_attrs)[1][ndim-1], INT_MAX)
<< "Large matrix dimensions (>= 2^31) are not supported";;
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
*/
std::vector<int> oshape(ndim);
for ( int i = 0; i < ndim-1; ++i ) {
if (i != axis) {
Expand Down
21 changes: 20 additions & 1 deletion tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch
from mxnet import gluon, nd
from common import with_seed, with_post_test_cleanup
from common import with_seed, with_post_test_cleanup, assertRaises
from mxnet.base import MXNetError
from nose.tools import with_setup
import unittest

Expand Down Expand Up @@ -1350,6 +1351,24 @@ def run_trsm(inp):
check_batch_trsm()


def test_linalg_large_dim():
def check_gemm():
A = mx.nd.ones(shape=(1, 2**32, 1))
B = mx.nd.ones(shape=(1, 2**32, 1))
C = mx.nd.ones(shape=(1, 1, 1))
assertRaises(MXNetError, mx.nd.linalg.gemm, \
A, B, C, transpose_b=True, alpha=1.0 , beta=1.0)

def check_gemm2():
A = mx.nd.ones(shape=(1, 1, 2**32))
B = mx.nd.ones(shape=(1, 1, 2**32))
assertRaises(MXNetError, mx.nd.linalg.gemm2, \
A, B, transpose_b=True, alpha=1.0)

check_gemm()
check_gemm2()


def test_basic():
def check_elementwise():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
Expand Down