diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index 963260367671..82db7fece3f9 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -395,6 +395,7 @@ def test_kernels( Tests LoRA kernels. """ torch.set_default_device(device) + torch.cuda.set_device(device) set_random_seed(seed) if op_type == "shrink": @@ -447,6 +448,7 @@ def test_kernels_hidden_size( Tests SGMV and LoRA kernels. """ torch.set_default_device(device) + torch.cuda.set_device(device) set_random_seed(seed) if op_type == "shrink":