From 0ff4cde92db0a3b2804d5f7d8fd2f4c81d472c84 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 26 Nov 2024 23:13:15 +0100 Subject: [PATCH] Generalize code in 'test_gpuhello.py' Signed-off-by: Anatoly Myachev --- python/test/unit/instrumentation/test_gpuhello.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/test/unit/instrumentation/test_gpuhello.py b/python/test/unit/instrumentation/test_gpuhello.py index 413c3f642405..bdc6ca90742c 100644 --- a/python/test/unit/instrumentation/test_gpuhello.py +++ b/python/test/unit/instrumentation/test_gpuhello.py @@ -31,7 +31,6 @@ def kernel3(BLOCK_SIZE: tl.constexpr): def func(x: torch.Tensor, y: torch.Tensor): output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) kernel1[grid](BLOCK_SIZE=1024) @@ -39,10 +38,10 @@ def func(x: torch.Tensor, y: torch.Tensor): kernel3[grid](BLOCK_SIZE=1024) -def test_op(capfd): +def test_op(capfd, device: str): size = 98432 - x = torch.rand(size, device='cuda') - y = torch.rand(size, device='cuda') + x = torch.rand(size, device=device) + y = torch.rand(size, device=device) func(x, y) stdout, stderr = capfd.readouterr() if 'LLVM_PASS_PLUGIN_PATH' in os.environ: