From 40063754cf39d2f1b4ea5aa8377b6a33a99aa327 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 21 May 2019 14:49:05 +0800 Subject: [PATCH] add test --- tests/python/unittest/test_operator.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ee94629485c8..7fbdd1ef1319 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6444,8 +6444,13 @@ def test_laop_6(): check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, atol=atol_bw, dtype=dtype) - a = np.sqrt(np.arange(4 * 4)).reshape(4, 4) + ## det(I + dot(v, v.T)) = 1 + dot(v.T, v) >= 1, so it's always invertible; + ## det is away from zero, so the value of logdet is stable + v = np.random.random(4) + a = np.eye(4) + np.outer(v, v) a = np.tile(a, (3, 1, 1)) + + # test matrix inverse r = np.eye(4) r = np.tile(r, (3, 1, 1)) test_inverse = mx.sym.linalg.inverse(data) @@ -6453,6 +6458,23 @@ def test_laop_6(): check_fw(test_eye, [a], [r]) check_grad(test_inverse, [a]) + # test matrix determinant + # det + r = np.linalg.det(a) + test_det = mx.sym.linalg.det(data) + check_fw(test_det, [a], [r]) + check_grad(test_det, [a]) + # logdet + r = np.log(np.linalg.det(a)) + test_logdet = mx.sym.linalg.logdet(data) + check_fw(test_logdet, [a], [r]) + check_grad(test_logdet, [a]) + # test slogdet + r = np.log(np.abs(np.linalg.det(a))) + _, test_slogdet = mx.sym.linalg.slogdet(data) + check_fw(test_slogdet, [a], [r]) + check_grad(test_slogdet, [a]) + @with_seed() def test_stack(): for _ in range(100):