diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py
index 2465bc6071..a95ab43da5 100644
--- a/tests/e2e/offline_inference/test_diffusion_lora.py
+++ b/tests/e2e/offline_inference/test_diffusion_lora.py
@@ -9,6 +9,7 @@
 
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
 from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
 
 # ruff: noqa: E402
 REPO_ROOT = Path(__file__).resolve().parents[2]
@@ -89,7 +90,7 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
             width=width,
             num_inference_steps=2,
             guidance_scale=0.0,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
             num_outputs_per_prompt=1,
         ),
     )
@@ -119,7 +120,7 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
             width=width,
             num_inference_steps=2,
             guidance_scale=0.0,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
             num_outputs_per_prompt=1,
             lora_request=lora_request,
             lora_scale=2.0,
diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py
index ad879d7517..2acfe5097a 100644
--- a/tests/e2e/offline_inference/test_stable_audio_model.py
+++ b/tests/e2e/offline_inference/test_stable_audio_model.py
@@ -8,6 +8,7 @@
 from tests.utils import hardware_test
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
 from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
 
 # ruff: noqa: E402
 REPO_ROOT = Path(__file__).resolve().parents[2]
@@ -41,7 +42,7 @@ def test_stable_audio_model(model_name: str):
         sampling_params_list=OmniDiffusionSamplingParams(
             num_inference_steps=4,  # Minimal steps for speed
             guidance_scale=7.0,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
             num_outputs_per_prompt=1,
             extra_args={
                 "audio_start_in_s": audio_start_in_s,
diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py
index d97991779e..96254f5b6e 100644
--- a/tests/e2e/offline_inference/test_teacache.py
+++ b/tests/e2e/offline_inference/test_teacache.py
@@ -17,6 +17,7 @@
 
 from tests.utils import hardware_test
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.platforms import current_omni_platform
 
 # ruff: noqa: E402
 REPO_ROOT = Path(__file__).resolve().parents[2]
@@ -63,7 +64,7 @@ def test_teacache(model_name: str):
             width=width,
             num_inference_steps=num_inference_steps,
             guidance_scale=0.0,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
             num_outputs_per_prompt=1,  # Single output for speed
         ),
     )