From 52a4f360fbe421f74caa4986ba05cc2b9131b33f Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Sun, 30 Nov 2025 22:26:20 +0000 Subject: [PATCH 1/7] enable encoder-decoder support - whisper should work with these changes Signed-off-by: Aditya Tewari --- csrc/cpu/cpu_attn.cpp | 1 - .../multimodal/generation/test_whisper.py | 26 +++++++++++++++ vllm/v1/attention/backends/cpu_attn.py | 32 +++++++++---------- vllm/v1/worker/utils.py | 14 ++++---- 4 files changed, 49 insertions(+), 24 deletions(-) diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 92f8bee5a47a..02c722ba031a 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata( input.casual = casual; input.isa = isa; input.enable_kv_split = enable_kv_split; - TORCH_CHECK(casual, "Only supports casual mask for now."); VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index eca2b61e37d5..880f9c0768a9 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -145,3 +145,29 @@ def test_models_distributed( tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, ) + + +@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) +@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) +@pytest.mark.cpu_model +def test_whisper_cpu(vllm_runner, model, dtype): + prompt_list = PROMPTS * 4 + expected_list = EXPECTED[model] * 4 + with vllm_runner( + model, + dtype=dtype, + max_model_len=448, + ) as vllm_model: + llm = vllm_model.llm + + sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + max_tokens=200, + ) + + outputs = llm.generate(prompt_list, sampling_params) + + for output, expected in zip(outputs, expected_list): + print(output.outputs[0].text) + assert output.outputs[0].text == expected diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index fed7dcdf293b..0a1a50fb097f 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -50,11 +50,13 @@ def get_name() -> str: @classmethod def supports_attn_type(cls, attn_type: str) -> bool: - """CPU attention supports decoder and encoder-only attention.""" + """CPU attention supports decoder, + encoder-only and encoder-decoder attention.""" return attn_type in ( AttentionType.DECODER, AttentionType.ENCODER, AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, ) @staticmethod @@ -172,21 +174,19 @@ def build( block_table_tensor = block_table_tensor[:num_decodes] sheduler_metadata = None - if causal: - # for decode batch, use the custom kernel - sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( - num_reqs=num_reqs, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - seq_lens=seq_lens, - dtype=self.dtype, - query_start_loc=query_start_loc, - causal=causal, - sliding_window_size=self.window_size, - isa=self.isa, - enable_kv_split=True, - ) + sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( + num_reqs=num_reqs, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + seq_lens=seq_lens, + dtype=self.dtype, + query_start_loc=query_start_loc, + causal=causal, + sliding_window_size=self.window_size, + isa=self.isa, + enable_kv_split=True, + ) attn_metadata = CPUAttentionMetadata( isa=self.isa, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 427a0d296b25..7418efd38f7e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -305,6 +305,7 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] + if len(layer_names) > 1: # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer @@ -313,13 +314,12 @@ def bind_kv_cache( # TODO - analyze where runner_kv_caches is used and the right # way to ensure it properly reflects multiple attention layers # in the same decoder block. - if current_platform.is_cuda_alike() or current_platform.is_xpu(): - # We know that the GPU runner is not impacted by this - # case. Some test code depends on runner_kv_caches, but - # not in a way that's impacted by ignoring this. - pass - else: - raise NotImplementedError + + # We know that the GPU / CPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. + pass + layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) From 9165cbfaa9f089756d00149ed4080c799adb1161 Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Thu, 4 Dec 2025 15:07:31 +0000 Subject: [PATCH 2/7] Address review comments Signed-off-by: Aditya Tewari --- vllm/v1/attention/backends/cpu_attn.py | 8 ++++++-- vllm/v1/worker/utils.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 0a1a50fb097f..c448899655b1 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -21,7 +21,7 @@ CommonAttentionMetadata, split_decodes_and_prefills, ) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec logger = init_logger(__name__) @@ -138,6 +138,7 @@ def __init__( self.window_size = -1 self.block_size = vllm_config.cache_config.block_size self.isa = _get_attn_isa(self.dtype, self.block_size) + self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec) def build( self, @@ -153,7 +154,10 @@ def build( seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - causal = common_attn_metadata.causal + if self.is_cross_attention: + causal = False + else: + causal = common_attn_metadata.causal sdpa_start_loc = query_start_loc num_decode_tokens = 0 diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 7418efd38f7e..83de58f39972 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -12,7 +12,6 @@ from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry -from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec From cd1ecfb691ded0a89429669d1f61571732eec3cf Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Sun, 30 Nov 2025 22:26:20 +0000 Subject: [PATCH 3/7] enable encoder-decoder support - whisper should work with these changes Signed-off-by: Aditya Tewari --- csrc/cpu/cpu_attn.cpp | 1 - .../multimodal/generation/test_whisper.py | 26 +++++++++++++++ vllm/v1/attention/backends/cpu_attn.py | 32 +++++++++---------- vllm/v1/worker/utils.py | 14 ++++---- 4 files changed, 49 insertions(+), 24 deletions(-) diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 92f8bee5a47a..02c722ba031a 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata( input.casual = casual; input.isa = isa; input.enable_kv_split = enable_kv_split; - TORCH_CHECK(casual, "Only supports casual mask for now."); VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index eca2b61e37d5..880f9c0768a9 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -145,3 +145,29 @@ def test_models_distributed( tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, ) + + +@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) +@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) +@pytest.mark.cpu_model +def test_whisper_cpu(vllm_runner, model, dtype): + prompt_list = PROMPTS * 4 + expected_list = EXPECTED[model] * 4 + with vllm_runner( + model, + dtype=dtype, + max_model_len=448, + ) as vllm_model: + llm = vllm_model.llm + + sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + max_tokens=200, + ) + + outputs = llm.generate(prompt_list, sampling_params) + + for output, expected in zip(outputs, expected_list): + print(output.outputs[0].text) + assert output.outputs[0].text == expected diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index fed7dcdf293b..0a1a50fb097f 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -50,11 +50,13 @@ def get_name() -> str: @classmethod def supports_attn_type(cls, attn_type: str) -> bool: - """CPU attention supports decoder and encoder-only attention.""" + """CPU attention supports decoder, + encoder-only and encoder-decoder attention.""" return attn_type in ( AttentionType.DECODER, AttentionType.ENCODER, AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, ) @staticmethod @@ -172,21 +174,19 @@ def build( block_table_tensor = block_table_tensor[:num_decodes] sheduler_metadata = None - if causal: - # for decode batch, use the custom kernel - sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( - num_reqs=num_reqs, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - seq_lens=seq_lens, - dtype=self.dtype, - query_start_loc=query_start_loc, - causal=causal, - sliding_window_size=self.window_size, - isa=self.isa, - enable_kv_split=True, - ) + sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( + num_reqs=num_reqs, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + seq_lens=seq_lens, + dtype=self.dtype, + query_start_loc=query_start_loc, + causal=causal, + sliding_window_size=self.window_size, + isa=self.isa, + enable_kv_split=True, + ) attn_metadata = CPUAttentionMetadata( isa=self.isa, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 427a0d296b25..7418efd38f7e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -305,6 +305,7 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] + if len(layer_names) > 1: # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer @@ -313,13 +314,12 @@ def bind_kv_cache( # TODO - analyze where runner_kv_caches is used and the right # way to ensure it properly reflects multiple attention layers # in the same decoder block. - if current_platform.is_cuda_alike() or current_platform.is_xpu(): - # We know that the GPU runner is not impacted by this - # case. Some test code depends on runner_kv_caches, but - # not in a way that's impacted by ignoring this. - pass - else: - raise NotImplementedError + + # We know that the GPU / CPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. + pass + layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) From 72f95ed67b6ce5db3693a85fe1ed53f745b33bd0 Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Thu, 4 Dec 2025 15:07:31 +0000 Subject: [PATCH 4/7] Address review comments Signed-off-by: Aditya Tewari --- vllm/v1/attention/backends/cpu_attn.py | 8 ++++++-- vllm/v1/worker/utils.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 0a1a50fb097f..c448899655b1 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -21,7 +21,7 @@ CommonAttentionMetadata, split_decodes_and_prefills, ) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec logger = init_logger(__name__) @@ -138,6 +138,7 @@ def __init__( self.window_size = -1 self.block_size = vllm_config.cache_config.block_size self.isa = _get_attn_isa(self.dtype, self.block_size) + self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec) def build( self, @@ -153,7 +154,10 @@ def build( seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - causal = common_attn_metadata.causal + if self.is_cross_attention: + causal = False + else: + causal = common_attn_metadata.causal sdpa_start_loc = query_start_loc num_decode_tokens = 0 diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 7418efd38f7e..83de58f39972 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -12,7 +12,6 @@ from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry -from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec From f037ce01fc5ab31c16576b5c86df4faea29b8249 Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Fri, 5 Dec 2025 10:13:10 +0000 Subject: [PATCH 5/7] address review comments Signed-off-by: Aditya Tewari --- .../scripts/hardware_ci/run-cpu-test-arm.sh | 5 +++ .../multimodal/generation/test_whisper.py | 34 ++++--------------- vllm/v1/attention/backends/cpu_attn.py | 6 +--- vllm/v1/worker/utils.py | 18 ++++++---- 4 files changed, 24 insertions(+), 39 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh index b5f6b2494792..56a4ba677731 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh @@ -36,6 +36,11 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run model tests + docker exec cpu-test bash -c " + set -e + pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model" + # Run kernel tests docker exec cpu-test bash -c " set -e diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 880f9c0768a9..5be3f4ef4e14 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -92,13 +92,14 @@ def run_test( *, tensor_parallel_size: int, distributed_executor_backend: str | None = None, + dtype: str = "half", ) -> None: prompt_list = PROMPTS * 10 expected_list = EXPECTED[model] * 10 with vllm_runner( model, - dtype="half", + dtype=dtype, max_model_len=448, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, @@ -118,14 +119,17 @@ def run_test( assert output.outputs[0].text == expected +@pytest.mark.cpu_model @pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize("dtype", ["half"]) @create_new_process_for_each_test() -def test_models(vllm_runner, model) -> None: +def test_models(vllm_runner, model, dtype) -> None: run_test( vllm_runner, model, tensor_parallel_size=1, + dtype=dtype, ) @@ -145,29 +149,3 @@ def test_models_distributed( tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, ) - - -@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) -@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) -@pytest.mark.cpu_model -def test_whisper_cpu(vllm_runner, model, dtype): - prompt_list = PROMPTS * 4 - expected_list = EXPECTED[model] * 4 - with vllm_runner( - model, - dtype=dtype, - max_model_len=448, - ) as vllm_model: - llm = vllm_model.llm - - sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - max_tokens=200, - ) - - outputs = llm.generate(prompt_list, sampling_params) - - for output, expected in zip(outputs, expected_list): - print(output.outputs[0].text) - assert output.outputs[0].text == expected diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index c448899655b1..394d0c2f6713 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -154,10 +154,7 @@ def build( seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - if self.is_cross_attention: - causal = False - else: - causal = common_attn_metadata.causal + causal = False if self.is_cross_attention else common_attn_metadata.causal sdpa_start_loc = query_start_loc num_decode_tokens = 0 @@ -177,7 +174,6 @@ def build( query_start_loc = query_start_loc[: num_decodes + 1] block_table_tensor = block_table_tensor[:num_decodes] - sheduler_metadata = None sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( num_reqs=num_reqs, num_heads=self.num_heads, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 83de58f39972..3638c8a93733 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -12,6 +12,7 @@ from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry +from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec @@ -304,7 +305,6 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] - if len(layer_names) > 1: # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer @@ -313,11 +313,17 @@ def bind_kv_cache( # TODO - analyze where runner_kv_caches is used and the right # way to ensure it properly reflects multiple attention layers # in the same decoder block. - - # We know that the GPU / CPU runner is not impacted by this - # case. Some test code depends on runner_kv_caches, but - # not in a way that's impacted by ignoring this. - pass + if ( + current_platform.is_cuda_alike() + or current_platform.is_xpu() + or current_platform.is_cpu() + ): + # We know that the GPU / CPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. + pass + else: + raise NotImplementedError layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) From eed969a87163d4d11b53d2ce99101862b66f8533 Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Fri, 5 Dec 2025 11:50:36 +0000 Subject: [PATCH 6/7] missed update Signed-off-by: Aditya Tewari --- .../multimodal/generation/test_whisper.py | 26 ------------------- vllm/v1/worker/utils.py | 2 -- 2 files changed, 28 deletions(-) diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index ad6935452fee..5be3f4ef4e14 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -149,29 +149,3 @@ def test_models_distributed( tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, ) - - -@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) -@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) -@pytest.mark.cpu_model -def test_whisper_cpu(vllm_runner, model, dtype): - prompt_list = PROMPTS * 4 - expected_list = EXPECTED[model] * 4 - with vllm_runner( - model, - dtype=dtype, - max_model_len=448, - ) as vllm_model: - llm = vllm_model.llm - - sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - max_tokens=200, - ) - - outputs = llm.generate(prompt_list, sampling_params) - - for output, expected in zip(outputs, expected_list): - print(output.outputs[0].text) - assert output.outputs[0].text == expected diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 61314888a964..126b46b34051 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -305,7 +305,6 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] - if len(layer_names) > 1: # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer @@ -325,7 +324,6 @@ def bind_kv_cache( pass else: raise NotImplementedError - layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) From ec88ad6fa7d098b5e499b81cde79732d6e4867fe Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Fri, 5 Dec 2025 14:47:47 +0000 Subject: [PATCH 7/7] fix whisper test Signed-off-by: Aditya Tewari --- .../models/multimodal/generation/test_whisper.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 5be3f4ef4e14..8c99b6b4690a 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -119,7 +119,6 @@ def run_test( assert output.outputs[0].text == expected -@pytest.mark.cpu_model @pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @pytest.mark.parametrize("dtype", ["half"]) @@ -133,6 +132,20 @@ def test_models(vllm_runner, model, dtype) -> None: ) +@pytest.mark.cpu_model +@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_cpu(vllm_runner, model, dtype) -> None: + # @create_new_process_for_each_test() does not work for some runners + # TODO: to fix cpu privilege issues in run-cpu-test-arm.sh + run_test( + vllm_runner, + model, + tensor_parallel_size=1, + dtype=dtype, + ) + + @multi_gpu_test(num_gpus=2) @pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])