Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Large Vector tests for DGL Ops Part 2 (#16497)
Browse files Browse the repository at this point in the history
* add hyperbolic, logical, sign and regression tests for large vector

* changed hyperbolic functions into existing trignometric functions

* fix trigo and simple bind needs shape as tuple

* fix logical ops, add with_seed

* fix arcosh in largearray, remove regression from largevector
  • Loading branch information
ChaiBapchya authored and anirudh2290 committed Oct 24, 2019
1 parent c3395ca commit 0742a9b
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,10 +1415,10 @@ def check_arcsinh():
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_arccosh():
x = create_input_for_trigonometric_ops([1, np.pi/2, 3*np.pi/4, np.pi])
x = create_input_for_trigonometric_ops([1, np.pi/2, 3*np.pi/4, np.pi, 5*np.pi/4])
y = nd.arccosh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying arccosh()
expected_output = [0, np.arccosh(np.pi/2), np.arccosh(3*np.pi/4), np.arccosh(np.pi)]
expected_output = [0, np.arccosh(np.pi/2), np.arccosh(3*np.pi/4), np.arccosh(np.pi), np.arccosh(5*np.pi/4)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_arctanh():
Expand Down
85 changes: 82 additions & 3 deletions tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,8 @@ def test_concat():
a = nd.ones(LARGE_X)
b = nd.zeros(LARGE_X)
c = nd.concat(a, b, dim=0)
assert c[0][0] == 1
assert c[-1][-1] == 0
assert c[0] == 1
assert c[-1] == 0
assert c.shape[0] == (2 * LARGE_X)


Expand Down Expand Up @@ -710,6 +710,37 @@ def test_full():
assert a[-1] == 3


def test_sign():
a = mx.nd.random.normal(-1, 1, shape=LARGE_X)
mx_res = mx.nd.sign(a)
assert_almost_equal(mx_res[-1].asnumpy(), np.sign(a[-1].asnumpy()))


def test_logical():
def check_logical_and(a, b):
mx_res = mx.nd.logical_and(a, b)
assert_almost_equal(mx_res[-1].asnumpy(), np.logical_and(a[-1].asnumpy(), b[-1].asnumpy()))

def check_logical_or(a, b):
mx_res = mx.nd.logical_or(a, b)
assert_almost_equal(mx_res[-1].asnumpy(), np.logical_or(a[-1].asnumpy(), b[-1].asnumpy()))

def check_logical_not(a, b):
mx_res = mx.nd.logical_not(a, b)
assert_almost_equal(mx_res[-1].asnumpy(), np.logical_not(a[-1].asnumpy(), b[-1].asnumpy()))

def check_logical_xor(a, b):
mx_res = mx.nd.logical_xor(a, b)
assert_almost_equal(mx_res[-1].asnumpy(), np.logical_xor(a[-1].asnumpy(), b[-1].asnumpy()))

a = mx.nd.ones(LARGE_X)
b = mx.nd.zeros(LARGE_X)
check_logical_and(a, b)
check_logical_or(a, b)
check_logical_not(a, b)
check_logical_xor(a, b)


def test_astype():
x = create_vector(size=LARGE_X//4)
x = nd.tile(x, 4)
Expand Down Expand Up @@ -752,7 +783,7 @@ def assert_correctness_of_rounding_ops(output, mid, expected_vals):

def test_rounding_ops():
x = create_input_for_rounding_ops()

def check_ceil():
y = nd.ceil(x)
# expected ouput for middle 5 values after applying ceil()
Expand Down Expand Up @@ -854,6 +885,48 @@ def check_tan():
expected_output = [-.577, -1, 0, 1, .577]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_arcsinh():
x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])
y = nd.arcsinh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying arcsinh()
expected_output = [np.arcsinh(-np.pi/2), np.arcsinh(-np.pi/4), 0, np.arcsinh(np.pi/4), np.arcsinh(np.pi/2)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_arccosh():
x = create_input_for_trigonometric_ops([1, np.pi/2, 3*np.pi/4, np.pi, 5*np.pi/4])
y = nd.arccosh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying arccosh()
expected_output = [0, np.arccosh(np.pi/2), np.arccosh(3*np.pi/4), np.arccosh(np.pi), np.arccosh(5*np.pi/4)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_arctanh():
x = create_input_for_trigonometric_ops([-1/4, -1/2, 0, 1/4, 1/2])
y = nd.arctanh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying arctanh()
expected_output = [np.arctanh(-1/4), np.arctanh(-1/2), 0, np.arctanh(1/4), np.arctanh(1/2)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_sinh():
x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])
y = nd.sinh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying sinh()
expected_output = [np.sinh(-np.pi/2), np.sinh(-np.pi/4), 0, np.sinh(np.pi/4), np.sinh(np.pi/2)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_cosh():
x = create_input_for_trigonometric_ops([0, 1, np.pi/2, 3*np.pi/4, np.pi])
y = nd.cosh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying cosh()
expected_output = [1, np.cosh(1), np.cosh(np.pi/2), np.cosh(3*np.pi/4), np.cosh(np.pi)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_tanh():
x = create_input_for_trigonometric_ops([-1/4, -1/2, 0, 1/4, 1/2])
y = nd.tanh(x)
# expected ouput for indices=(0, 1, -3, -2, -1) after applying tanh()
expected_output = [np.tanh(-1/4), np.tanh(-1/2), 0, np.tanh(1/4), np.tanh(1/2)]
assert_correctness_of_trigonometric_ops(y, expected_output)

def check_radians():
x = create_input_for_trigonometric_ops([0, 90, 180, 270, 360])
y = nd.radians(x)
Expand All @@ -874,6 +947,12 @@ def check_degrees():
check_sin()
check_cos()
check_tan()
check_arcsinh()
check_arccosh()
check_arctanh()
check_sinh()
check_cosh()
check_tanh()
check_radians()
check_degrees()

Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2931,6 +2931,7 @@ def test_big_transpose():
assert_allclose(x_np, z.asnumpy().astype('uint8'))


@with_seed()
def test_larger_transpose():
x = mx.nd.random.normal(shape=(50,51))
y = mx.nd.transpose(x)
Expand Down

0 comments on commit 0742a9b

Please sign in to comment.