Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed May 21, 2019
1 parent ba8198c commit 4006375
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6444,15 +6444,37 @@ def test_laop_6():
check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw,
atol=atol_bw, dtype=dtype)

a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)
## det(I + dot(v, v.T)) = 1 + dot(v.T, v) >= 1, so it's always invertible;
## det is away from zero, so the value of logdet is stable
v = np.random.random(4)
a = np.eye(4) + np.outer(v, v)
a = np.tile(a, (3, 1, 1))

# test matrix inverse
r = np.eye(4)
r = np.tile(r, (3, 1, 1))
test_inverse = mx.sym.linalg.inverse(data)
test_eye = mx.sym.linalg.gemm2(data, test_inverse)
check_fw(test_eye, [a], [r])
check_grad(test_inverse, [a])

# test matrix determinant
# det
r = np.linalg.det(a)
test_det = mx.sym.linalg.det(data)
check_fw(test_det, [a], [r])
check_grad(test_det, [a])
# logdet
r = np.log(np.linalg.det(a))
test_logdet = mx.sym.linalg.logdet(data)
check_fw(test_logdet, [a], [r])
check_grad(test_logdet, [a])
# test slogdet
r = np.log(np.abs(np.linalg.det(a)))
_, test_slogdet = mx.sym.linalg.slogdet(data)
check_fw(test_slogdet, [a], [r])
check_grad(test_slogdet, [a])

@with_seed()
def test_stack():
for _ in range(100):
Expand Down

0 comments on commit 4006375

Please sign in to comment.