This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit: refactor test

arcadiaphy committed May 16, 2019
1 parent 808104d · commit e728d22
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions tests/python/unittest/test_operator.py
@@ -6118,15 +6118,6 @@ def test_laop():
     if grad_check == 1:
         check_grad(test_sumlogdiag, [a])
 
-    # test inverse
-    a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)
-    a = rep_3x(a, 4, 4)
-    r = np.linalg.inv(a)
-    test_inverse = mx.sym.linalg.inverse(data1)
-    check_fw(test_inverse, [a], [r])
-    if grad_check == 1:
-        check_grad(test_inverse, [a])
-
 
 # Tests for operators linalg.syrk, linalg.gelqf
 
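The block removed above lived inside test_laop and compared the forward output of linalg.inverse element-wise against np.linalg.inv, batching the input with the file's rep_3x helper. As a rough orientation (not the author's code), the reference values it checked against can be reproduced in plain NumPy; rep_3x is defined elsewhere in test_operator.py and, judging by the replacement code in the hunk below, appears to replicate the 4x4 matrix into a (3, 4, 4) batch much like np.tile:

    import numpy as np

    a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)   # the 4x4 input matrix used by the test
    a = np.tile(a, (3, 1, 1))                     # batch of 3 copies; stand-in for rep_3x(a, 4, 4)
    r = np.linalg.inv(a)                          # expected forward output, shape (3, 4, 4)
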
@@ -6424,6 +6415,34 @@ def test_laop_5():
         check_symbolic_forward(test_trian, [data_in], [res_trian])
         check_numeric_gradient(test_trian, [data_in])
 
+# Tests for linalg.inverse
+@with_seed()
+def test_laop_5():
+    dtype = np.float64
+    rtol_fw = 1e-7
+    atol_fw = 1e-9
+    num_eps = 1e-6
+    rtol_bw = 1e-5
+    atol_bw = 1e-6
+
+    data = mx.symbol.Variable('data')
+
+    check_fw = lambda sym, location, expected:\
+        check_symbolic_forward(sym, location, expected, rtol=rtol_fw,
+                               atol=atol_fw, dtype=dtype)
+    check_grad = lambda sym, location:\
+        check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw,
+                               atol=atol_bw, dtype=dtype)
+
+    a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)
+    a = np.tile(a, (3, 1, 1))
+    r = np.eye(4)
+    r = np.tile(r, (3, 1, 1))
+    test_inverse = mx.sym.linalg.inverse(data)
+    test_eye = mx.sym.linalg.gemm2(data, test_inverse)
+    check_fw(test_eye, [a], [r])
+    check_grad(test_inverse, [a])
+
 @with_seed()
 def test_stack():
     for _ in range(100):
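The refactored test is self-contained: instead of comparing element-wise against np.linalg.inv, it multiplies the input batch by its computed inverse with linalg.gemm2 and checks that the product is the identity, and check_grad then validates the operator's backward pass against a finite-difference gradient; for Y = inv(A) the analytic gradient being exercised is dL/dA = -inv(A).T @ (dL/dY) @ inv(A).T. A minimal NumPy sketch of the forward property, using NumPy as a stand-in for the MXNet symbols and mirroring the test's rtol_fw/atol_fw tolerances:

    import numpy as np

    a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)
    a = np.tile(a, (3, 1, 1))                      # the same (3, 4, 4) batch the test feeds in
    eye = np.matmul(a, np.linalg.inv(a))           # plays the role of gemm2(data, inverse(data))
    assert np.allclose(eye, np.tile(np.eye(4), (3, 1, 1)), rtol=1e-7, atol=1e-9)

Checking A @ inv(A) against the identity keeps the expected value trivial and independent of any external inverse implementation, which is presumably why the comparison against np.linalg.inv was dropped.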
