diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 2446107ad466..7fa570af4923 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -485,6 +485,14 @@ def test_dot(): assert_almost_equal(c, C.asnumpy(), atol=atol) +@raises(mx.base.MXNetError) +def test_gemm_overflow(): + # 100 * 6000 * 7000 overflows signed int32 + a = mx.nd.random.uniform(shape=(100, 6000, 1)) + b = mx.nd.random.uniform(shape=(100, 1, 7000)) + c = mx.nd.batch_dot(a, b) + + @with_seed() def test_reduce(): sample_num = 200