From 859905941356070be72e84783ad76a93d3b186ef Mon Sep 17 00:00:00 2001 From: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com> Date: Thu, 8 Jan 2026 09:15:55 +0800 Subject: [PATCH] Support for cross-attention and whisper model Signed-off-by: gh924 --- tests/e2e/singlecard/test_models.py | 29 ++++++++++ tests/ut/attention/test_attention_v1.py | 23 -------- vllm_ascend/attention/attention_v1.py | 38 +++++++------ vllm_ascend/platform.py | 8 +++ vllm_ascend/worker/model_runner_v1.py | 73 +++++++++++++++---------- 5 files changed, 103 insertions(+), 68 deletions(-) diff --git a/tests/e2e/singlecard/test_models.py b/tests/e2e/singlecard/test_models.py index 8068eda5158..e0464a55b90 100644 --- a/tests/e2e/singlecard/test_models.py +++ b/tests/e2e/singlecard/test_models.py @@ -21,6 +21,8 @@ import pytest from modelscope import snapshot_download # type: ignore +from vllm import SamplingParams +from vllm.assets.audio import AudioAsset from tests.e2e.conftest import VllmRunner @@ -32,6 +34,10 @@ "OpenBMB/MiniCPM4-0.5B", ] +WHISPER_MODELS = [ + "openai-mirror/whisper-large-v3-turbo", +] + @pytest.mark.parametrize("model", MINICPM_MODELS) def test_minicpm(model) -> None: @@ -44,3 +50,26 @@ def test_minicpm(model) -> None: max_model_len=512, gpu_memory_utilization=0.7) as runner: runner.generate_greedy(example_prompts, max_tokens) + + +@pytest.mark.parametrize("model", WHISPER_MODELS) +def test_whisper(model) -> None: + prompts = ["<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"] + audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate] + + sampling_params = SamplingParams(temperature=0.2, + max_tokens=10, + stop_token_ids=None) + + with VllmRunner(snapshot_download(model), + max_model_len=448, + max_num_seqs=5, + dtype="bfloat16", + block_size=128, + gpu_memory_utilization=0.9) as runner: + outputs = runner.generate(prompts=prompts, + audios=audios, + sampling_params=sampling_params) + + assert outputs is not None, "Generated outputs should not be None." + assert len(outputs) > 0, "Generated outputs should not be empty." 
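Note: the new `test_whisper` case above exercises the encoder-decoder path end to end through `VllmRunner`. For readers who want to try the same model outside the test harness, below is a minimal offline sketch using vLLM's multimodal `LLM.generate` API. The model id, prompt, audio asset, and engine parameters mirror the test; the exact invocation is an assumption based on vLLM's encoder-decoder multimodal examples, not part of this patch.

```python
# Minimal offline Whisper transcription sketch (illustrative, not part of the patch).
# Assumes the openai-mirror/whisper-large-v3-turbo checkpoint and the bundled
# "mary_had_lamb" audio asset used by the e2e test.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(
    model="openai-mirror/whisper-large-v3-turbo",
    max_model_len=448,            # Whisper decoder context length
    max_num_seqs=5,
    dtype="bfloat16",
    gpu_memory_utilization=0.9,
)

prompt = {
    "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
    "multi_modal_data": {
        "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
    },
}

outputs = llm.generate(prompt, SamplingParams(temperature=0.2, max_tokens=10))
print(outputs[0].outputs[0].text)
```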
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index c5f0fc1b4d5..0dbd5837f93 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -320,26 +320,3 @@ def test_forward_decode_only_swa_seq_len_mismatch( mock_fused_infer_attention_score.assert_called_once() assert output.shape == (10, 8, 64) - - @patch('torch_npu._npu_reshape_and_cache') - def test_forward_raise_error(self, mock_paged_attention): - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - output = torch.empty_like(query) - - metadata = self.attn_metadata - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.query_lens = torch.tensor([10]) - metadata.seq_lens = torch.tensor([10]) - metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - metadata.num_decodes = 0 - metadata.num_prefills = 10 - layer = self.layer_no_quant - - with self.assertRaises(NotImplementedError): - self.impl_error.forward(layer, query, key, value, kv_cache, - metadata, output) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index d19d3369b99..69746d2ccda 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -32,7 +32,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder) from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.context_parallel.common_cp import ( @@ -255,6 +255,9 @@ def build( seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + if isinstance(self.kv_cache_spec, CrossAttentionSpec): + seq_lens = common_attn_metadata.seq_lens + slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32) attn_state = common_attn_metadata.attn_state # Get attn_mask and swa_mask from singleton AttentionMaskBuilder @@ -496,6 +499,9 @@ def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, block_size = 128 block_table = None actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q + if self.attn_type == AttentionType.ENCODER_DECODER: + actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, + dim=0).tolist() elif attn_metadata.attn_state == \ AscendAttentionState.PrefillCacheHit: batch_size = attn_metadata.seq_lens.shape[0] @@ -577,7 +583,7 @@ def forward_fused_infer_attention(self, query: torch.Tensor, = self._get_fia_params(key, value, attn_metadata) num_tokens = attn_metadata.actual_seq_lengths_q[-1] query = query[:num_tokens] - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER: key = key[:num_tokens] value = value[:num_tokens] # Get workspace from cache or calculate it if not present. 
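Note on the `_get_fia_params` change above: for `AttentionType.ENCODER_DECODER` the KV lengths come from the cross-attention (encoder) sequence lengths, turned into a running sum, because the fused-infer-attention call is fed per-request cumulative end offsets rather than raw lengths; the encoder key/value are also not sliced to the decoder token count, since their length is independent of the scheduled decoder tokens. A small sketch of the cumulative-sum shape (illustrative only, with hypothetical encoder lengths):

```python
# Illustration of the cumulative KV-length handling (not part of the patch).
import torch

encoder_seq_lens = torch.tensor([1500, 1500, 750])  # hypothetical per-request encoder lengths
actual_seq_lengths_kv = torch.cumsum(encoder_seq_lens, dim=0).tolist()
assert actual_seq_lengths_kv == [1500, 3000, 3750]  # end offset of each request's encoder K/V
```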
@@ -669,23 +675,29 @@ def reshape_and_cache( if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping + encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER) if get_ascend_device_type() == AscendDeviceType.A5: # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping. # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key. # If it's necessary, the slots should be sliced. torch_npu.npu_scatter_pa_kv_cache( - key=key[:attn_metadata.num_actual_tokens], - value=value[:attn_metadata.num_actual_tokens].contiguous(), + key=key[:attn_metadata.num_actual_tokens] + if not encoder_decoder else key, + value=value[:attn_metadata.num_actual_tokens].contiguous() + if not encoder_decoder else value, key_cache=self.key_cache, value_cache=self.value_cache, slot_mapping=slots) else: torch_npu._npu_reshape_and_cache( - key=key[:attn_metadata.num_actual_tokens], - value=value[:attn_metadata.num_actual_tokens], + key=key[:attn_metadata.num_actual_tokens] + if not encoder_decoder else key, + value=value[:attn_metadata.num_actual_tokens] + if not encoder_decoder else value, key_cache=self.key_cache, value_cache=self.value_cache, - slot_indices=slots[:attn_metadata.num_actual_tokens]) + slot_indices=slots[:attn_metadata.num_actual_tokens] + if not encoder_decoder else slots) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() return key, value @@ -741,18 +753,12 @@ def forward( " for AscendAttentionBackendImpl") assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - attn_type = self.attn_type - if attn_type not in [ - AttentionType.DECODER, AttentionType.ENCODER_ONLY - ]: - raise NotImplementedError("Encoder/Decoder cross-attention " - "is not implemented for " - "PallasAttentionBackendImpl") num_tokens = query.shape[0] if attn_metadata is None: return output.fill_(0) - key, value = self.reshape_and_cache(key, value, kv_cache, - attn_metadata) + if key is not None and value is not None: + key, value = self.reshape_and_cache(key, value, kv_cache, + attn_metadata) # pooling model branch if attn_metadata.model_runner_type == "pooling": attn_output = self._forward_encoder_attention( diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 06f8be7bca2..f37a1ad8785 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -238,6 +238,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # encoder-decoder models currently only support piecewise mode + if model_config and model_config.is_encoder_decoder is True: + if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + logger.warning( + "encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE " + ) + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # get custom compile backend for graph fusion compilation_config.oot_compiler = cls.get_compile_backend() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 86fa7a4b071..337ff003ca9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -55,7 +55,7 @@ from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput -from 
vllm.v1.kv_cache_interface import (AttentionSpec, +from vllm.v1.kv_cache_interface import (AttentionSpec, CrossAttentionSpec, EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, @@ -315,7 +315,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # the block_sizes in the kv cache config. self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.model_config.max_model_len, + max_model_len=max(self.model_config.max_model_len, + self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -485,7 +486,8 @@ def _prepare_inputs( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor], int]: + Optional[torch.Tensor], Optional[torch.Tensor], int, dict[str, + Any]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -729,7 +731,11 @@ def _prepare_inputs( # _prepare_inputs may reorder the batch, so we must gather # multi-modal outputs after that to ensure the correct order - if self.is_multimodal_model: + if vllm_version_is('0.13.0'): + model_kwargs = self._init_model_kwargs(num_input_tokens) + else: + model_kwargs = self._init_model_kwargs() + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: self.multimodal_cpu_fields = ["grid_thw"] self._prepare_multimodal_fields() with self.maybe_get_ec_connector_output( @@ -796,6 +802,13 @@ def _prepare_inputs( else: positions = self.positions.gpu[:num_input_tokens] + # Run the encoder, just like we do with other multimodal inputs. + if self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: + input_ids = self.input_ids.gpu[:total_num_scheduled_tokens] + positions = self.positions.gpu[:total_num_scheduled_tokens] + encoder_outputs = self._execute_mm_encoder(scheduler_output) + model_kwargs.update({"encoder_outputs": encoder_outputs}) + # type: ignore if get_pp_group().is_first_rank: intermediate_tensors = None @@ -880,6 +893,11 @@ def _prepare_inputs( # in the same group share the same metadata. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens( + scheduler_output.num_scheduled_tokens or {}, + kv_cache_group_spec.kv_cache_spec, + self.input_batch.num_reqs, + ) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to @@ -973,7 +991,8 @@ def _prepare_inputs( decode_token_per_req=self.decode_token_per_req, prefill_context_parallel_metadata=self.long_seq_metadata, max_seq_len=0, - ) + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_cpu=encoder_seq_lens_cpu) if self.speculative_config and self.pcp_size * self.dcp_size > 1: # For pcp + spec decode, we flatten block_table @@ -1055,7 +1074,7 @@ def _prepare_inputs( num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors, - max_num_scheduled_tokens) + max_num_scheduled_tokens, model_kwargs) # all-gather one hidden-states in sp scene @staticmethod @@ -1087,22 +1106,13 @@ def _all_gather_hidden_states_and_aux(hidden_states): def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, - inputs_embeds): + inputs_embeds, model_kwargs): assert self.model is not None - if vllm_version_is('0.13.0'): - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **self._init_model_kwargs(maybe_padded_num_tokens)) - else: - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **self._init_model_kwargs()) + hidden_states = self.model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ @@ -1461,9 +1471,9 @@ def execute_model( (attn_metadata, positions, num_scheduled_tokens_np, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, - intermediate_tensors, - max_query_len) = (self._prepare_inputs(scheduler_output, - intermediate_tensors)) + intermediate_tensors, max_query_len, + model_kwargs) = (self._prepare_inputs(scheduler_output, + intermediate_tensors)) if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() @@ -1508,7 +1518,7 @@ def execute_model( hidden_states = self._generate_process_reqs_hidden_states( maybe_padded_num_tokens, input_ids, positions, - intermediate_tensors, inputs_embeds) + intermediate_tensors, inputs_embeds, model_kwargs) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -2148,7 +2158,7 @@ def _dummy_run( num_sampled_tokens): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - if self.is_multimodal_model: + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] elif self.enable_prompt_embeds: @@ -2542,7 +2552,7 @@ def _reshape_kv_cache_tensors( # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue - if isinstance(kv_cache_spec, FullAttentionSpec): + if 
isinstance(kv_cache_spec, AttentionSpec): raw_dsa_k_tensor = None if self.use_sparse: raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore @@ -2717,7 +2727,8 @@ def may_reinitialize_input_batch(self, "for more details.") self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.model_config.max_model_len, + max_model_len=max(self.model_config.max_model_len, + self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -2885,7 +2896,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # encoder-only attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype) else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}")
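Summary of the model-runner wiring: cross-attention layers now receive a `CrossAttentionSpec` KV cache instead of raising `NotImplementedError`, the input batch is sized to `max(max_model_len, max_encoder_len)`, and when the scheduler has encoder inputs the audio encoder is executed once and its outputs are forwarded to the decoder via `model_kwargs["encoder_outputs"]`. Below is a simplified sketch of the spec-selection branch as it reads after this patch; the `CrossAttentionSpec` arguments are taken from the diff, while the surrounding function shape and the `FullAttentionSpec` branch are assumptions (the real method also handles sliding-window, MLA, and Mamba specs).

```python
# Simplified, illustrative sketch of get_kv_cache_spec() per-layer logic after this patch.
from vllm.attention import AttentionType
from vllm.v1.kv_cache_interface import CrossAttentionSpec, FullAttentionSpec


def spec_for_layer(attn_module, block_size, kv_cache_dtype):
    if attn_module.attn_type == AttentionType.DECODER:
        # Ordinary self-attention: full per-token KV cache.
        return FullAttentionSpec(block_size=block_size,
                                 num_kv_heads=attn_module.num_kv_heads,
                                 head_size=attn_module.head_size,
                                 dtype=kv_cache_dtype)
    if attn_module.attn_type == AttentionType.ENCODER_ONLY:
        # Encoder-only attention keeps no KV cache.
        return None
    if attn_module.attn_type == AttentionType.ENCODER_DECODER:
        # New in this patch: cross-attention caches the encoder K/V once per request.
        return CrossAttentionSpec(block_size=block_size,
                                  num_kv_heads=attn_module.num_kv_heads,
                                  head_size=attn_module.head_size,
                                  dtype=kv_cache_dtype)
    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
```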