diff --git a/python/test/unit/instrumentation/test_gpuhello.py b/python/test/unit/instrumentation/test_gpuhello.py index 413c3f642405..bdc6ca90742c 100644 --- a/python/test/unit/instrumentation/test_gpuhello.py +++ b/python/test/unit/instrumentation/test_gpuhello.py @@ -31,7 +31,6 @@ def kernel3(BLOCK_SIZE: tl.constexpr): def func(x: torch.Tensor, y: torch.Tensor): output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) kernel1[grid](BLOCK_SIZE=1024) @@ -39,10 +38,10 @@ def func(x: torch.Tensor, y: torch.Tensor): kernel3[grid](BLOCK_SIZE=1024) -def test_op(capfd): +def test_op(capfd, device: str): size = 98432 - x = torch.rand(size, device='cuda') - y = torch.rand(size, device='cuda') + x = torch.rand(size, device=device) + y = torch.rand(size, device=device) func(x, y) stdout, stderr = capfd.readouterr() if 'LLVM_PASS_PLUGIN_PATH' in os.environ: