diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 64a3a2c4eb81..0a6456e881f1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5551,7 +5551,7 @@ def kernel(input): @pytest.mark.parametrize("dtype_str", ['float32', 'float64']) -def test_math_extern(dtype_str): +def test_math_extern(dtype_str, device): if is_interpreter(): pytest.skip('math_extern does not work in the interpreter mode') @@ -5575,8 +5575,8 @@ def kernel( x = numpy_random(shape, dtype_str=dtype_str, rs=rs) y_ref = np.tanh(x) - x_tri = to_triton(x, device='cuda') - y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device='cuda') + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device=device) kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) # compare np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)