diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e8bfaba4736d..f4fc868f2f0c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5467,33 +5467,6 @@ def infer_shape(self, in_shape): def create_operator(self, ctx, shapes, dtypes): return Dot() -def _custom_exc3(seed): - def custom_exc3(): - def f(in_data, out_data): - out_data[0][:] = mx.nd.dot(in_data[0], in_data[1]) - out_data[0].wait_to_read() - _build_dot_custom(f, 'Dot3') - n = int(1e8) - a = mx.nd.zeros((n, 1)) - b = mx.nd.zeros((1, n)) - # trigger OOM - c = mx.nd.Custom(a, b, op_type='Dot3') - c.wait_to_read() - assert_raises(MXNetError, custom_exc3) - -def _custom_exc4(seed): - def custom_exc4(): - def f(in_data, out_data): - out_data[0][:] = mx.nd.dot(in_data[0], in_data[1]) - _build_dot_custom(f, 'Dot4') - n = int(1e8) - a = mx.nd.zeros((n, 1)) - b = mx.nd.zeros((1, n)) - # trigger OOM - c = mx.nd.Custom(a, b, op_type='Dot4') - c.wait_to_read() - assert_raises(MXNetError, custom_exc4) - @with_seed() def test_custom_op_exc(): # test except handling @@ -5523,8 +5496,35 @@ def f(in_data, out_data): assert_raises(MXNetError, custom_exc2) # 3. error in real execution - run_in_spawned_process(_custom_exc3, {}) - run_in_spawned_process(_custom_exc4, {}) + if default_context().device_type == 'cpu': + def custom_exc3(): + def f(in_data, out_data): + dot = mx.nd.dot(in_data[0], in_data[1]) + # input to Cholesky factorization should be + # symmetric positive-definite, error will be + # triggered in op execution on cpu + out_data[0][:] = mx.nd.linalg.potrf(dot) + out_data[0].wait_to_read() + _build_dot_custom(f, 'Dot3') + a = mx.nd.zeros((2, 1)) + b = mx.nd.zeros((1, 2)) + c = mx.nd.Custom(a, b, op_type='Dot3') + c.wait_to_read() + assert_raises(MXNetError, custom_exc3) + + def custom_exc4(): + def f(in_data, out_data): + dot = mx.nd.dot(in_data[0], in_data[1]) + # input to Cholesky factorization should be + # symmetric positive-definite, error will be + # triggered in op execution on cpu + out_data[0][:] = mx.nd.linalg.potrf(dot) + _build_dot_custom(f, 'Dot4') + a = mx.nd.zeros((2, 1)) + b = mx.nd.zeros((1, 2)) + c = mx.nd.Custom(a, b, op_type='Dot4') + c.wait_to_read() + assert_raises(MXNetError, custom_exc4) @with_seed()