29 changes: 29 additions & 0 deletions tests/e2e/singlecard/test_models.py
@@ -21,6 +21,8 @@

import pytest
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset

from tests.e2e.conftest import VllmRunner

@@ -32,6 +34,10 @@
"OpenBMB/MiniCPM4-0.5B",
]

WHISPER_MODELS = [
"openai-mirror/whisper-large-v3-turbo",
]


@pytest.mark.parametrize("model", MINICPM_MODELS)
def test_minicpm(model) -> None:
Expand All @@ -44,3 +50,26 @@ def test_minicpm(model) -> None:
max_model_len=512,
gpu_memory_utilization=0.7) as runner:
runner.generate_greedy(example_prompts, max_tokens)


@pytest.mark.parametrize("model", WHISPER_MODELS)
def test_whisper(model) -> None:
prompts = ["<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"]
audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]

sampling_params = SamplingParams(temperature=0.2,
max_tokens=10,
stop_token_ids=None)

with VllmRunner(snapshot_download(model),
max_model_len=448,
max_num_seqs=5,
dtype="bfloat16",
block_size=128,
gpu_memory_utilization=0.9) as runner:
outputs = runner.generate(prompts=prompts,
audios=audios,
sampling_params=sampling_params)

assert outputs is not None, "Generated outputs should not be None."
assert len(outputs) > 0, "Generated outputs should not be empty."
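
For context, the new test_whisper case drives vLLM's audio multimodal path end to end on a single card. A minimal offline sketch of the same flow outside the test harness, assuming vLLM's public LLM API and the Hugging Face model id (both assumptions, not part of this diff):

# Hedged sketch: Whisper transcription through vLLM's offline LLM API,
# mirroring what test_whisper covers. Engine arguments are illustrative.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(model="openai/whisper-large-v3-turbo",
          max_model_len=448,
          max_num_seqs=5,
          dtype="bfloat16")
audio_and_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
outputs = llm.generate(
    {
        "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        "multi_modal_data": {"audio": audio_and_rate},
    },
    SamplingParams(temperature=0.2, max_tokens=10),
)
print(outputs[0].outputs[0].text)  # transcribed text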
23 changes: 0 additions & 23 deletions tests/ut/attention/test_attention_v1.py
@@ -320,26 +320,3 @@ def test_forward_decode_only_swa_seq_len_mismatch(
mock_fused_infer_attention_score.assert_called_once()

assert output.shape == (10, 8, 64)

@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_paged_attention):
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)

metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant

with self.assertRaises(NotImplementedError):
self.impl_error.forward(layer, query, key, value, kv_cache,
metadata, output)
38 changes: 22 additions & 16 deletions vllm_ascend/attention/attention_v1.py
@@ -32,7 +32,7 @@
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import (
@@ -255,6 +255,9 @@ def build(
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
if isinstance(self.kv_cache_spec, CrossAttentionSpec):
seq_lens = common_attn_metadata.seq_lens
slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
attn_state = common_attn_metadata.attn_state

# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
@@ -496,6 +499,9 @@ def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
if self.attn_type == AttentionType.ENCODER_DECODER:
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens,
dim=0).tolist()
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.seq_lens.shape[0]
@@ -577,7 +583,7 @@ def forward_fused_infer_attention(self, query: torch.Tensor,
= self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
key = key[:num_tokens]
value = value[:num_tokens]
# Get workspace from cache or calculate it if not present.
@@ -669,23 +675,29 @@ def reshape_and_cache(
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
if get_ascend_device_type() == AscendDeviceType.A5:
# TODO: Once Eagle reaches this path, it may fail because of the 0th dim of slot_mapping.
# Check whether the 0th dim of slot_mapping must equal the 0th dim of key;
# if it must, the slots should be sliced.
torch_npu.npu_scatter_pa_kv_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens].contiguous(),
key=key[:attn_metadata.num_actual_tokens]
if not encoder_decoder else key,
value=value[:attn_metadata.num_actual_tokens].contiguous()
if not encoder_decoder else value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_mapping=slots)
else:
torch_npu._npu_reshape_and_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens],
key=key[:attn_metadata.num_actual_tokens]
if not encoder_decoder else key,
value=value[:attn_metadata.num_actual_tokens]
if not encoder_decoder else value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens])
slot_indices=slots[:attn_metadata.num_actual_tokens]
if not encoder_decoder else slots)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value
@@ -741,18 +753,12 @@ def forward(
" for AscendAttentionBackendImpl")

assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/Decoder cross-attention "
"is not implemented for "
"PallasAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
if key is not None and value is not None:
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
# pooling model branch
if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention(
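
Note on the cross-attention handling added to attention_v1.py above: for ENCODER_DECODER attention the KV side is sized by the encoder sequence lengths, so the cumulative offsets passed to the fused-infer-attention kernel come from attn_metadata.seq_lens rather than the decoder query lengths. A purely illustrative sketch of what that cumsum produces (the 1500 value is an assumption matching Whisper's fixed encoder output length):

# Illustrative only: cumulative encoder lengths become actual_seq_lengths_kv
# for the cross-attention call, as in _get_fia_params above.
import torch

encoder_seq_lens = torch.tensor([1500, 1500])
actual_seq_lengths_kv = torch.cumsum(encoder_seq_lens, dim=0).tolist()
print(actual_seq_lengths_kv)  # [1500, 3000]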
8 changes: 8 additions & 0 deletions vllm_ascend/platform.py
@@ -238,6 +238,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# encoder-decoder models currently only support piecewise mode
if model_config and model_config.is_encoder_decoder is True:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.warning(
"encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE "
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# get custom compile backend for graph fusion
compilation_config.oot_compiler = cls.get_compile_backend()

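
To observe the new fallback, requesting FULL_DECODE_ONLY for an encoder-decoder model should log the warning above and downgrade to PIECEWISE. A hedged usage sketch; the compilation_config spelling follows vLLM's public API and accepted values may vary by version:

# Hedged sketch: on Ascend, the platform hook downgrades FULL_DECODE_ONLY to
# PIECEWISE for encoder-decoder models such as Whisper.
from vllm import LLM

llm = LLM(model="openai/whisper-large-v3-turbo",
          compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"})
# Expected log line:
# "encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE"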
73 changes: 44 additions & 29 deletions vllm_ascend/worker/model_runner_v1.py
@@ -55,7 +55,7 @@
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.kv_cache_interface import (AttentionSpec, CrossAttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
@@ -315,7 +315,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
# the block_sizes in the kv cache config.
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_model_len=max(self.model_config.max_model_len,
self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
@@ -485,7 +486,8 @@ def _prepare_inputs(
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor], int]:
Optional[torch.Tensor], Optional[torch.Tensor], int, dict[str,
Any]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -729,7 +731,11 @@ def _prepare_inputs(

# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
if vllm_version_is('0.13.0'):
model_kwargs = self._init_model_kwargs(num_input_tokens)
else:
model_kwargs = self._init_model_kwargs()
if self.is_multimodal_model and not self.model_config.is_encoder_decoder:
self.multimodal_cpu_fields = ["grid_thw"]
self._prepare_multimodal_fields()
with self.maybe_get_ec_connector_output(
@@ -796,6 +802,13 @@ def _prepare_inputs(
else:
positions = self.positions.gpu[:num_input_tokens]

# Run the encoder, just like we do with other multimodal inputs.
if self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
input_ids = self.input_ids.gpu[:total_num_scheduled_tokens]
positions = self.positions.gpu[:total_num_scheduled_tokens]
encoder_outputs = self._execute_mm_encoder(scheduler_output)
model_kwargs.update({"encoder_outputs": encoder_outputs})

# type: ignore
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -880,6 +893,11 @@ def _prepare_inputs(
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
scheduler_output.num_scheduled_tokens or {},
kv_cache_group_spec.kv_cache_spec,
self.input_batch.num_reqs,
)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
@@ -973,7 +991,8 @@ def _prepare_inputs(
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=self.long_seq_metadata,
max_seq_len=0,
)
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_cpu=encoder_seq_lens_cpu)

if self.speculative_config and self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten block_table
@@ -1055,7 +1074,7 @@ def _prepare_inputs(
num_input_tokens, num_tokens_across_dp,
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds, intermediate_tensors,
max_num_scheduled_tokens)
max_num_scheduled_tokens, model_kwargs)

# all-gather one hidden-states in sp scene
@staticmethod
@@ -1087,22 +1106,13 @@ def _all_gather_hidden_states_and_aux(hidden_states):
def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
inputs_embeds, model_kwargs):
assert self.model is not None
if vllm_version_is('0.13.0'):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**self._init_model_kwargs(maybe_padded_num_tokens))
else:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**self._init_model_kwargs())
hidden_states = self.model(input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs)

forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
@@ -1461,9 +1471,9 @@ def execute_model(
(attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors,
max_query_len) = (self._prepare_inputs(scheduler_output,
intermediate_tensors))
intermediate_tensors, max_query_len,
model_kwargs) = (self._prepare_inputs(scheduler_output,
intermediate_tensors))

if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
@@ -1508,7 +1518,7 @@

hidden_states = self._generate_process_reqs_hidden_states(
maybe_padded_num_tokens, input_ids, positions,
intermediate_tensors, inputs_embeds)
intermediate_tensors, inputs_embeds, model_kwargs)

self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
@@ -2148,7 +2158,7 @@ def _dummy_run(
num_sampled_tokens):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_padded <= self.max_num_tokens
if self.is_multimodal_model:
if self.is_multimodal_model and not self.model_config.is_encoder_decoder:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
elif self.enable_prompt_embeds:
@@ -2542,7 +2552,7 @@ def _reshape_kv_cache_tensors(

# TODO: remove this after the OOM issue is located and fixed; otherwise, some models may
# encounter OOM issues
if isinstance(kv_cache_spec, FullAttentionSpec):
if isinstance(kv_cache_spec, AttentionSpec):
raw_dsa_k_tensor = None
if self.use_sparse:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
@@ -2717,7 +2727,8 @@ def may_reinitialize_input_batch(self,
"for more details.")
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_model_len=max(self.model_config.max_model_len,
self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
@@ -2885,7 +2896,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype)
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")