diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 24970eaf9e5e..a44ba327b3a1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -498,11 +498,11 @@ def test_relu():
     def frelu(x):
         return np.maximum(x, 0.0)
     def frelu_grad(x):
-        return 1.0 * (x > 0.0)
+        return np.float32(1.0) * (x > np.float32(0.0))
     shape = (3, 4)
     x = mx.symbol.Variable("x")
     y = mx.sym.relu(x)
-    xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
+    xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype('float32')
     eps = 1e-4
     # Avoid finite difference method inaccuracies due to discontinuous gradient at the origin.
     # Here we replace small problematic inputs with 1.0. Repro issue with seed 97264195.
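
For context (a standalone sketch, not part of the patch): the casts matter because NumPy promotes a Python float multiplied against a bool array to float64, so the original frelu_grad silently upcast float32 test data and the comparison against MXNet's float32 gradient could fail at tight tolerances. The snippet below only demonstrates that dtype promotion; the variable names are illustrative.

import numpy as np

x = np.random.uniform(low=-1.0, high=1.0, size=(3, 4)).astype('float32')

old_grad = 1.0 * (x > 0.0)                          # Python float -> promotes to float64
new_grad = np.float32(1.0) * (x > np.float32(0.0))  # stays float32, matching the input dtype

print(old_grad.dtype, new_grad.dtype)  # float64 float32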