diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d7eee19252ae..160f17bbe8c6 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6118,15 +6118,6 @@ def test_laop(): if grad_check == 1: check_grad(test_sumlogdiag, [a]) - # test inverse - a = np.sqrt(np.arange(4 * 4)).reshape(4, 4) - a = rep_3x(a, 4, 4) - r = np.linalg.inv(a) - test_inverse = mx.sym.linalg.inverse(data1) - check_fw(test_inverse, [a], [r]) - if grad_check == 1: - check_grad(test_inverse, [a]) - # Tests for operators linalg.syrk, linalg.gelqf @@ -6424,6 +6415,34 @@ def test_laop_5(): check_symbolic_forward(test_trian, [data_in], [res_trian]) check_numeric_gradient(test_trian, [data_in]) +# Tests for linalg.inverse +@with_seed() +def test_laop_5(): + dtype = np.float64 + rtol_fw = 1e-7 + atol_fw = 1e-9 + num_eps = 1e-6 + rtol_bw = 1e-5 + atol_bw = 1e-6 + + data = mx.symbol.Variable('data') + + check_fw = lambda sym, location, expected:\ + check_symbolic_forward(sym, location, expected, rtol=rtol_fw, + atol=atol_fw, dtype=dtype) + check_grad = lambda sym, location:\ + check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, + atol=atol_bw, dtype=dtype) + + a = np.sqrt(np.arange(4 * 4)).reshape(4, 4) + a = np.tile(a, (3, 1, 1)) + r = np.eye(4) + r = np.tile(r, (3, 1, 1)) + test_inverse = mx.sym.linalg.inverse(data) + test_eye = mx.sym.linalg.gemm2(data, test_inverse) + check_fw(test_eye, [a], [r]) + check_grad(test_inverse, [a]) + @with_seed() def test_stack(): for _ in range(100):