Skip to content

Commit

Permalink
Update test case
Browse files Browse the repository at this point in the history
  • Loading branch information
juliusshufan committed Apr 18, 2019
1 parent 2c5c20c commit b1b6355
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,7 +2200,7 @@ def test_context_num_gpus():
assert mx.context.num_gpus() > 0

def math_log(shape, dtype, check_value):
np_x = np.random.rand(shape[0], shape[1])
np_x = np.random.rand(*tuple(shape))
x = mx.nd.array(np_x, dtype=dtype)
mx.nd.waitall()
y = mx.nd.log(data=x)
Expand All @@ -2213,7 +2213,7 @@ def math_log(shape, dtype, check_value):
assert_almost_equal(y.asnumpy(), y_.asnumpy())

def math_erf(shape, dtype, check_value):
np_x = np.random.rand(shape[0], shape[1])
np_x = np.random.rand(*tuple(shape))
x = mx.nd.array(np_x, dtype=dtype)
mx.nd.waitall()
y = mx.nd.erf(data=x)
Expand All @@ -2226,7 +2226,7 @@ def math_erf(shape, dtype, check_value):
assert_almost_equal(y.asnumpy(), y_.asnumpy())

def math_square(shape, dtype, check_value):
np_x = np.random.rand(shape[0], shape[1])
np_x = np.random.rand(*tuple(shape))
x = mx.nd.array(np_x, dtype=dtype)
mx.nd.waitall()
y = mx.nd.square(data=x)
Expand All @@ -2252,12 +2252,9 @@ def run_math(op, shape, dtype="float32", check_value=True):
def test_math():
ops = ['log', 'erf', 'square']
check_value= True
lshape = 1000
rshapes = [1, 10, 100, 1000, 10000]
shape_lst = [[1000], [100,1000], [10,100,100], [10,100,100,100]]
dtypes = ["float32", "float64"]
for rshape in rshapes:
shape = (lshape, rshape)
print("shape:(%d, %d), " % (lshape, rshape), end="")
for shape in shape_lst:
for dtype in dtypes:
for op in ops:
run_math(op, shape, dtype, check_value=check_value)
Expand Down

0 comments on commit b1b6355

Please sign in to comment.