diff --git a/.github/workflows/bot_pr_create.yaml b/.github/workflows/bot_pr_create.yaml index bdb75d25ae0..3fd1adfd89b 100644 --- a/.github/workflows/bot_pr_create.yaml +++ b/.github/workflows/bot_pr_create.yaml @@ -37,7 +37,7 @@ jobs: steps: - name: Get vLLM version run: | - VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d + VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> "$GITHUB_ENV" - name: Checkout repository diff --git a/.github/workflows/dockerfiles/Dockerfile.lint b/.github/workflows/dockerfiles/Dockerfile.lint index bb27cf537b7..3cb23c9fcf2 100644 --- a/.github/workflows/dockerfiles/Dockerfile.lint +++ b/.github/workflows/dockerfiles/Dockerfile.lint @@ -27,7 +27,7 @@ RUN apt-get update -y && \ ARG VLLM_REPO=https://github.com/vllm-project/vllm.git # For lint purpose, actually we need make a main2main matching. -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/.github/workflows/pr_test_full.yaml b/.github/workflows/pr_test_full.yaml index 33c5c1046cd..49971e73c41 100644 --- a/.github/workflows/pr_test_full.yaml +++ b/.github/workflows/pr_test_full.yaml @@ -80,7 +80,7 @@ jobs: name: e2e-full strategy: matrix: - vllm_version: [14acf429ac08b6d538ca6feb3e06b6d13895804d] + vllm_version: [29e48707e8144b78dd5d756f793c26a405043f3d] needs: [changes] if: ${{ needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.e2e_tracker == true }} uses: ./.github/workflows/_e2e_test.yaml diff --git a/.github/workflows/pr_test_light.yaml b/.github/workflows/pr_test_light.yaml index 3984d6c5b5c..789644f3a8f 100644 --- a/.github/workflows/pr_test_light.yaml +++ b/.github/workflows/pr_test_light.yaml @@ -41,7 +41,7 @@ jobs: lint: uses: ./.github/workflows/_pre_commit.yml with: - vllm: 14acf429ac08b6d538ca6feb3e06b6d13895804d + vllm: 29e48707e8144b78dd5d756f793c26a405043f3d changes: runs-on: linux-aarch64-a2b3-0 outputs: @@ -90,7 +90,7 @@ jobs: if: ${{ needs.lint.result == 'success' && (needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.ut_tracker == 'true') }} strategy: matrix: - vllm_version: [14acf429ac08b6d538ca6feb3e06b6d13895804d] + vllm_version: [29e48707e8144b78dd5d756f793c26a405043f3d] uses: ./.github/workflows/_unit_test.yaml with: vllm: ${{ matrix.vllm_version }} @@ -102,7 +102,7 @@ jobs: name: e2e-light strategy: matrix: - vllm_version: [14acf429ac08b6d538ca6feb3e06b6d13895804d] + vllm_version: [29e48707e8144b78dd5d756f793c26a405043f3d] # Note (yikun): If CI resource are limited we can split job into two chain jobs needs: [lint, changes] # only trigger e2e test after lint passed and the change is e2e related with pull request. diff --git a/.github/workflows/schedule_codecov_refresh.yaml b/.github/workflows/schedule_codecov_refresh.yaml index 1b8e6d3837b..b9401208e4a 100644 --- a/.github/workflows/schedule_codecov_refresh.yaml +++ b/.github/workflows/schedule_codecov_refresh.yaml @@ -33,7 +33,7 @@ jobs: name: refresh codecov strategy: matrix: - vllm_version: [14acf429ac08b6d538ca6feb3e06b6d13895804d] + vllm_version: [29e48707e8144b78dd5d756f793c26a405043f3d] uses: ./.github/workflows/_unit_test.yaml with: vllm: ${{ matrix.vllm_version }} diff --git a/Dockerfile b/Dockerfile index 8e6b0839dd2..038cde47265 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,7 +48,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/Dockerfile.310p b/Dockerfile.310p index e6d04df1028..14331c3a751 100644 --- a/Dockerfile.310p +++ b/Dockerfile.310p @@ -33,7 +33,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/Dockerfile.310p.openEuler b/Dockerfile.310p.openEuler index 1c438e7ec90..338335af656 100644 --- a/Dockerfile.310p.openEuler +++ b/Dockerfile.310p.openEuler @@ -32,7 +32,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/Dockerfile.a3 b/Dockerfile.a3 index 387f91d2728..e36934df215 100644 --- a/Dockerfile.a3 +++ b/Dockerfile.a3 @@ -50,7 +50,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/Dockerfile.a3.openEuler b/Dockerfile.a3.openEuler index 8865b64318d..2a0089fb8b9 100644 --- a/Dockerfile.a3.openEuler +++ b/Dockerfile.a3.openEuler @@ -49,7 +49,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/Dockerfile.openEuler b/Dockerfile.openEuler index 6725e201dd1..3fd276b763a 100644 --- a/Dockerfile.openEuler +++ b/Dockerfile.openEuler @@ -49,7 +49,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_COMMIT=14acf429ac08b6d538ca6feb3e06b6d13895804d +ARG VLLM_COMMIT=29e48707e8144b78dd5d756f793c26a405043f3d RUN git init /vllm-workspace/vllm && \ git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ git -C /vllm-workspace/vllm checkout FETCH_HEAD diff --git a/tests/e2e/singlecard/test_cpu_offloading.py b/tests/e2e/singlecard/test_cpu_offloading.py index 61b15597720..73e8aa35591 100644 --- a/tests/e2e/singlecard/test_cpu_offloading.py +++ b/tests/e2e/singlecard/test_cpu_offloading.py @@ -5,6 +5,7 @@ import msgspec import msgspec.msgpack +import pytest import zmq from vllm import LLM, SamplingParams, TokensPrompt from vllm.config import KVEventsConfig, KVTransferConfig @@ -127,6 +128,7 @@ def _accuracy_test(llm: LLM, subscriber: MockSubscriber): assert success_count >= 0.5 * test_count +@pytest.mark.skip(reason="cpu offload connector is deprecated.") def test_cpu_offloading() -> None: """ Tests OffloadingConnector with CPUOffloadingSpec. diff --git a/vllm_ascend/ops/gdn.py b/vllm_ascend/ops/gdn.py new file mode 100644 index 00000000000..f8f50982ef8 --- /dev/null +++ b/vllm_ascend/ops/gdn.py @@ -0,0 +1,415 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torch_npu +from einops import rearrange +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fla.ops import ( + fused_recurrent_gated_delta_rule, +) +from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd +from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention +from vllm.triton_utils import triton +from vllm.v1.attention.backend import AttentionMetadata # type: ignore +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata +from vllm.v1.attention.backends.utils import PAD_SLOT_ID + +from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector +from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule +from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat +from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch +from vllm_ascend.ops.triton.mamba.causal_conv1d import causal_conv1d_update_npu +from vllm_ascend.utils import enable_sp + + +def to_int64_tuple(tensor: torch.Tensor) -> tuple: + return tuple(tensor.to(torch.int64).tolist()) + + +class AscendGatedDeltaNetAttention(GatedDeltaNetAttention): + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + """ + Forward pass with three parts: + 1. Input projection + 2. Core attention (custom op) + 3. Output projection + """ + if not self.gqa_interleaved_layout: + mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) + num_tokens = mixed_qkvz.size(0) + qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size + z_size = self.value_dim // self.tp_size + mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) + z = z.reshape(z.size(0), -1, self.head_v_dim) + ba, _ = self.in_proj_ba(hidden_states) + b, a = ba.chunk(2, dim=-1) + + b = b.contiguous() + a = a.contiguous() + else: + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + projected_states_ba, _ = self.in_proj_ba(hidden_states) + num_tokens = projected_states_qkvz.size(0) + + mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( + projected_states_qkvz, + projected_states_ba, + triton.cdiv(self.num_k_heads, self.tp_size), + triton.cdiv(self.num_v_heads, self.tp_size), + self.head_k_dim, + self.head_v_dim, + ) + + # ============================================================ + # Part 2: Core Attention (Custom Op) + # ============================================================ + # Note: we should not use torch.empty here like other attention backends, + # see discussions in https://github.com/vllm-project/vllm/pull/28182 + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + torch.ops.vllm.gdn_attention_core( + mixed_qkv, + b, + a, + core_attn_out, + self.prefix, + ) + + # ============================================================ + # Part 3: Output Projection + # ============================================================ + maybe_save_kv_layer_to_connector("", []) + z_shape_og = z.shape + # Reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + output[:num_tokens], _ = self.out_proj(core_attn_out) + + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """ + Core attention computation (called by custom op). + """ + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + + if not enable_sp(): + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # 1. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 1.1: Process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update_npu( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 1.2: Process the remaining part + if attn_metadata.num_prefills > 0: + if mixed_qkv_non_spec is not None: + conv_weights_T = conv_weights.transpose(0, 1) + activation_num = 1 if self.activation else 0 + mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom( + mixed_qkv_non_spec, + conv_weights_T, + conv_state=self_kv_cache[0], + bias_opt=self.conv1d.bias, + query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc), + cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor), + initial_state_mode_opt=to_int64_tuple(has_initial_state), + num_accepted_tokens_opt=[], + activation_mode=activation_num, + pad_slot_id=PAD_SLOT_ID, + run_mode=0, + ) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update_npu( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec) + + # 2. Recurrent attention + if self.gqa_interleaved_layout: + # Qwen3Next: torch_npu ops support float16/bf16 ssm_state. + # g/beta are needed for both spec-decode and decode, so compute unconditionally. + g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 2.1: Process the multi-query part + if spec_sequence_masks is not None: + cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1] + actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + query_spec = l2norm_fwd(query_spec) + key_spec = l2norm_fwd(key_spec) + core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule( + query=query_spec.squeeze(0), + key=key_spec.squeeze(0), + value=value_spec.squeeze(0), + g=g_spec.squeeze(0), + beta=beta_spec.squeeze(0), + state=ssm_state, + scale=key_spec.shape[-1] ** -0.5, + actual_seq_lengths=actual_seq_lengths, + ssm_state_indices=spec_state_indices_tensor.flatten(), + num_accepted_tokens=num_accepted_tokens.to(torch.int32), + ).unsqueeze(0) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 2.2: Process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous() + initial_state[~has_initial_state, ...] = 0 + non_spec_chunked_prefill_meta = getattr(attn_metadata, "non_spec_chunked_prefill_meta", None) + (core_attn_out_non_spec, last_recurrent_state) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + prebuilt_meta=non_spec_chunked_prefill_meta, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + ssm_state[non_spec_state_indices_tensor] = ( + last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype) + ) + elif attn_metadata.num_decodes > 0: + cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1] + actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + query_non_spec = l2norm_fwd(query_non_spec) + key_non_spec = l2norm_fwd(key_non_spec) + core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule( + query=query_non_spec.squeeze(0), + key=key_non_spec.squeeze(0), + value=value_non_spec.squeeze(0), + g=g_non_spec.squeeze(0) if g_non_spec is not None else g_non_spec, + beta=beta_non_spec.squeeze(0) if beta_non_spec is not None else beta_non_spec, + state=ssm_state, + scale=key_non_spec.shape[-1] ** -0.5, + actual_seq_lengths=actual_seq_lengths, + ssm_state_indices=non_spec_state_indices_tensor, + ).unsqueeze(0) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + else: + # Qwen3.5: torch_npu ops do not support float32 ssm_state, use FLA ops instead. + # NOTE: Once torch_npu supports float32 ssm_state, this branch can be removed. + if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: + g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 2.1: Process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 2.2: Process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + non_spec_chunked_prefill_meta = getattr(attn_metadata, "non_spec_chunked_prefill_meta", None) + (core_attn_out_non_spec, last_recurrent_state) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + prebuilt_meta=non_spec_chunked_prefill_meta, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + elif attn_metadata.num_decodes > 0: + core_attn_out_spec = None + core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update( + A_log=self.A_log.contiguous(), + dt_bias=self.dt_bias.contiguous(), + q=query_non_spec.contiguous(), + k=key_non_spec.contiguous(), + v=value_non_spec.contiguous(), + a=a.contiguous(), + b=b.contiguous(), + initial_state_source=ssm_state, + initial_state_indices=non_spec_state_indices_tensor, + cu_seqlens=non_spec_query_start_loc, + use_qk_l2norm_in_kernel=True, + softplus_beta=1.0, + softplus_threshold=20.0, + ) + else: + core_attn_out_spec, core_attn_out_non_spec = None, None + maybe_save_kv_layer_to_connector("", []) + + # 3. Merge core attention output + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + merged_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) + if not enable_sp(): + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens] + elif spec_sequence_masks is not None: + if not enable_sp(): + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens] + else: + if not enable_sp(): + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens] diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index fcaa87e4283..fdf3e1a2bc2 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -39,7 +39,6 @@ if is_310p(): import vllm_ascend.patch.worker.patch_qwen3_5_310 # noqa else: - import vllm_ascend.patch.worker.patch_qwen3_next # noqa import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa import vllm_ascend.patch.worker.patch_qwen3_5 # noqa import vllm_ascend.patch.worker.patch_rejection_sampler # noqa diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py index 3cf5ff22bb7..055fdfee847 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_5.py +++ b/vllm_ascend/patch/worker/patch_qwen3_5.py @@ -18,319 +18,11 @@ import torch -from einops import rearrange from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule -from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update -from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet +from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer from vllm.model_executor.models.qwen3_next import Qwen3NextAttention -from vllm.v1.attention.backend import AttentionMetadata # type: ignore -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm_ascend.ascend_forward_context import _EXTRA_CTX -from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector -from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update -from vllm_ascend.ops.triton.fla.utils import clear_ssm_states -from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch -from vllm_ascend.utils import enable_sp - - -def to_int64_tuple(t): - t = t.to(torch.int64) - if t.dim() == 0: - return (t.item(),) - return tuple(t.tolist()) - - -class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet): - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - """ - Forward pass with three parts: - 1. Input projection - 2. Core attention (custom op) - 3. Output projection - """ - - # ============================================================ - # Part 1: Input Projection - # ============================================================ - mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) - num_tokens = mixed_qkvz.size(0) - qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size - z_size = self.value_dim // self.tp_size - mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) - z = z.reshape(z.size(0), -1, self.head_v_dim) - ba, _ = self.in_proj_ba(hidden_states) - b, a = ba.chunk(2, dim=-1) - - b = b.contiguous() - a = a.contiguous() - - # ============================================================ - # Part 2: Core Attention (Custom Op) - # ============================================================ - # Note: we should not use torch.empty here like other attention backends, - # see discussions in https://github.com/vllm-project/vllm/pull/28182 - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - # ============================================================ - # Part 3: Output Projection - # ============================================================ - z_shape_og = z.shape - # Reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") - o_out, _ = self.out_proj(core_attn_out) - actual_num_tokens = o_out.shape[0] - output[:actual_num_tokens] = o_out - - def _forward_core( - self, - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - ): - # Core attention computation (called by custom op). - - # NOTE: The processing logic of Qwen3_5GatedDeltaNet is the same as Qwen3NextGatedDeltaNet. - # However, because the ops `torch_npu.npu_recurrent_gated_delta_rule` - # currently does not support `ssm_state` inputs in float32 format, - # we temporarily retain the current _forward_core implementation. - # Once the ops supports float32 `ssm_state`, this patch should be removed. - - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - spec_query_start_loc = attn_metadata.spec_query_start_loc - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - num_actual_tokens = attn_metadata.num_actual_tokens - num_accepted_tokens = attn_metadata.num_accepted_tokens - - if not enable_sp(): - mixed_qkv = mixed_qkv[:num_actual_tokens] - b = b[:num_actual_tokens] - a = a[:num_actual_tokens] - - # 1. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - mixed_qkv_spec = mixed_qkv - mixed_qkv_non_spec = None - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) - else: - mixed_qkv_spec = None - mixed_qkv_non_spec = mixed_qkv - - # 1.1: Process the multi-query part - if spec_sequence_masks is not None: - mixed_qkv_spec = causal_conv1d_update( - mixed_qkv_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes], - num_accepted_tokens=num_accepted_tokens, - query_start_loc=spec_query_start_loc, - max_query_len=spec_state_indices_tensor.size(-1), - validate_data=False, - ) - - # 1.2: Process the remaining part - if attn_metadata.num_prefills > 0: - if mixed_qkv_non_spec is not None: - conv_weights_T = conv_weights.transpose(0, 1) - activation_num = 1 if self.activation else 0 - mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom( - mixed_qkv_non_spec, - conv_weights_T, - conv_state=self_kv_cache[0], - bias_opt=self.conv1d.bias, - query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc), - cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor), - initial_state_mode_opt=to_int64_tuple(has_initial_state), - num_accepted_tokens_opt=[], - activation_mode=activation_num, - pad_slot_id=PAD_SLOT_ID, - run_mode=0, - ) - elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update( - mixed_qkv_non_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens], - validate_data=True, - ) - else: - mixed_qkv_non_spec = None - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) - query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec) - - if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: - g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) - else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta - - # 2. Recurrent attention - - # 2.1: Process the multi-query part - if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=True, - ) - else: - core_attn_out_spec, last_recurrent_state = None, None - - # 2.2: Process the remaining part - if attn_metadata.num_prefills > 0: - initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() - clear_ssm_states(initial_state, has_initial_state) - non_spec_chunked_prefill_meta = getattr( - attn_metadata, - "non_spec_chunked_prefill_meta", - None, - ) - ( - core_attn_out_non_spec, - last_recurrent_state, - ) = chunk_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=initial_state, - output_final_state=True, - cu_seqlens=non_spec_query_start_loc, - prebuilt_meta=non_spec_chunked_prefill_meta, - head_first=False, - use_qk_l2norm_in_kernel=True, - ) - # Init cache - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype) - elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, - use_qk_l2norm_in_kernel=True, - ) - else: - core_attn_out_non_spec, last_recurrent_state = None, None - - elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update( - A_log=self.A_log.contiguous(), - dt_bias=self.dt_bias.contiguous(), - q=query_non_spec.contiguous(), - k=key_non_spec.contiguous(), - v=value_non_spec.contiguous(), - a=a.contiguous(), - b=b.contiguous(), - initial_state_source=ssm_state, - initial_state_indices=non_spec_state_indices_tensor, - cu_seqlens=non_spec_query_start_loc, - use_qk_l2norm_in_kernel=True, - softplus_beta=1.0, - softplus_threshold=20.0, - ) - - # 3. Merge core attention output - if spec_sequence_masks is not None and core_attn_out_non_spec is not None: - merged_out = torch.empty( - (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), - dtype=core_attn_out_non_spec.dtype, - device=core_attn_out_non_spec.device, - ) - merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) - merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) - if not enable_sp(): - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens] - elif spec_sequence_masks is not None: - if not enable_sp(): - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens] - else: - if not enable_sp(): - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens] - maybe_save_kv_layer_to_connector("", []) class AscendQwen3NextAttention(Qwen3NextAttention): @@ -443,7 +135,5 @@ def forward( return hidden_states, residual -Qwen3_5GatedDeltaNet.forward = AscendQwen3_5GatedDeltaNet.forward -Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core -Qwen3NextAttention.forward = AscendQwen3NextAttention.forward Qwen3_5DecoderLayer.forward = AscendQwen3_5DecoderLayer.forward +Qwen3NextAttention.forward = AscendQwen3NextAttention.forward diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py deleted file mode 100644 index cb2c216c729..00000000000 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ /dev/null @@ -1,315 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# from collections.abc import Iterable -# mypy: ignore-errors - -import torch -import torch_npu -from einops import rearrange -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule -from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd -from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update -from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet -from vllm.triton_utils import triton -from vllm.v1.attention.backend import AttentionMetadata # type: ignore -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm.v1.attention.backends.utils import PAD_SLOT_ID - -from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector -from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat -from vllm_ascend.ops.triton.fla.utils import clear_ssm_states -from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch -from vllm_ascend.patch.worker.patch_qwen3_5 import to_int64_tuple -from vllm_ascend.utils import enable_sp - - -class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet): - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - """ - Forward pass with three parts: - 1. Input projection - 2. Core attention (custom op) - 3. Output projection - """ - - # ============================================================ - # Part 1: Input Projection - # ============================================================ - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) - num_tokens = projected_states_qkvz.size(0) - - mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( - projected_states_qkvz, - projected_states_ba, - triton.cdiv(self.num_k_heads, self.tp_size), - triton.cdiv(self.num_v_heads, self.tp_size), - self.head_k_dim, - self.head_v_dim, - ) - - # ============================================================ - # Part 2: Core Attention (Custom Op) - # ============================================================ - # Note: we should not use torch.empty here like other attention backends, - # see discussions in https://github.com/vllm-project/vllm/pull/28182 - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - - # ============================================================ - # Part 3: Output Projection - # ============================================================ - maybe_save_kv_layer_to_connector("", []) - z_shape_og = z.shape - # Reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) - - def _forward_core( - self, - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - ): - """ - Core attention computation (called by custom op). - """ - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - spec_query_start_loc = attn_metadata.spec_query_start_loc - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - num_actual_tokens = attn_metadata.num_actual_tokens - num_accepted_tokens = attn_metadata.num_accepted_tokens - - if not enable_sp(): - mixed_qkv = mixed_qkv[:num_actual_tokens] - b = b[:num_actual_tokens] - a = a[:num_actual_tokens] - - # 1. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - mixed_qkv_spec = mixed_qkv - mixed_qkv_non_spec = None - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) - else: - mixed_qkv_spec = None - mixed_qkv_non_spec = mixed_qkv - - # 1.1: Process the multi-query part - if spec_sequence_masks is not None: - mixed_qkv_spec = causal_conv1d_update( - mixed_qkv_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes], - num_accepted_tokens=num_accepted_tokens, - query_start_loc=spec_query_start_loc, - max_query_len=spec_state_indices_tensor.size(-1), - validate_data=False, - ) - - # 1.2: Process the remaining part - if attn_metadata.num_prefills > 0: - if mixed_qkv_non_spec is not None: - conv_weights_T = conv_weights.transpose(0, 1) - activation_num = 1 if self.activation else 0 - mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom( - mixed_qkv_non_spec, - conv_weights_T, - conv_state=self_kv_cache[0], - bias_opt=self.conv1d.bias, - query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc), - cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor), - initial_state_mode_opt=to_int64_tuple(has_initial_state), - num_accepted_tokens_opt=[], - activation_mode=activation_num, - pad_slot_id=PAD_SLOT_ID, - run_mode=0, - ) - elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update( - mixed_qkv_non_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens], - validate_data=True, - ) - else: - mixed_qkv_non_spec = None - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) - query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec) - g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) - else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta - - # 2. Recurrent attention - # 2.1: Process the multi-query part - if spec_sequence_masks is not None: - cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1] - actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - query_spec = l2norm_fwd(query_spec) - key_spec = l2norm_fwd(key_spec) - core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule( - query=query_spec.squeeze(0), - key=key_spec.squeeze(0), - value=value_spec.squeeze(0), - g=g_spec.squeeze(0), - beta=beta_spec.squeeze(0), - state=ssm_state, - scale=key_spec.shape[-1] ** -0.5, - actual_seq_lengths=actual_seq_lengths, - ssm_state_indices=spec_state_indices_tensor.flatten(), - num_accepted_tokens=num_accepted_tokens.to(torch.int32), - ).unsqueeze(0) - else: - core_attn_out_spec, last_recurrent_state = None, None - - # 2.2: Process the remaining part - if attn_metadata.num_prefills > 0: - initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous() - - clear_ssm_states(initial_state, has_initial_state) - non_spec_chunked_prefill_meta = getattr( - attn_metadata, - "non_spec_chunked_prefill_meta", - None, - ) - ( - core_attn_out_non_spec, - last_recurrent_state, - ) = chunk_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=initial_state, - output_final_state=True, - cu_seqlens=non_spec_query_start_loc, - prebuilt_meta=non_spec_chunked_prefill_meta, - head_first=False, - use_qk_l2norm_in_kernel=True, - ) - ssm_state[non_spec_state_indices_tensor] = ( - last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype) - ) - - elif attn_metadata.num_decodes > 0: - cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1] - actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - query_non_spec = l2norm_fwd(query_non_spec) - key_non_spec = l2norm_fwd(key_non_spec) - core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule( - query=query_non_spec.squeeze(0), - key=key_non_spec.squeeze(0), - value=value_non_spec.squeeze(0), - g=g_non_spec.squeeze(0), - beta=beta_non_spec.squeeze(0), - state=ssm_state, - scale=key_non_spec.shape[-1] ** -0.5, - actual_seq_lengths=actual_seq_lengths, - ssm_state_indices=non_spec_state_indices_tensor, - ).unsqueeze(0) - else: - core_attn_out_non_spec, last_recurrent_state = None, None - - # 3. Merge core attention output - if spec_sequence_masks is not None and core_attn_out_non_spec is not None: - merged_out = torch.empty( - (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), - dtype=core_attn_out_non_spec.dtype, - device=core_attn_out_non_spec.device, - ) - merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) - merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) - if not enable_sp(): - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens] - elif spec_sequence_masks is not None: - if not enable_sp(): - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens] - else: - if not enable_sp(): - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens] - - -Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward -Qwen3NextGatedDeltaNet._forward_core = AscendQwen3Next_GatedDeltaNet._forward_core diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 5f96ba08f51..10eb7622885 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -30,6 +30,7 @@ import regex as re import torch +from transformers import PretrainedConfig from vllm.config import get_current_vllm_config from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -583,7 +584,12 @@ def get_kv_quant_split_factor(self, layer_name, kv_head_dim_list): kv_head_dim_list = [k_quant_head_dim, v_quant_head_dim] return calc_split_factor(kv_head_dim_list) - def maybe_update_config(self, model_name: str, revision: str | None = None) -> None: + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ) -> None: """Load the ModelSlim quantization config from model directory. This method is called by vllm after get_quant_config() returns @@ -599,6 +605,7 @@ def maybe_update_config(self, model_name: str, revision: str | None = None) -> N Args: model_name: Path to the model directory or HuggingFace / ModelScope repo id. + hf_config: The Hugging Face config of the model revision: Optional revision (branch, tag, or commit hash) for remote repos. """ diff --git a/vllm_ascend/spec_decode/draft_proposer.py b/vllm_ascend/spec_decode/draft_proposer.py index a35d965611b..60635ce3c9b 100644 --- a/vllm_ascend/spec_decode/draft_proposer.py +++ b/vllm_ascend/spec_decode/draft_proposer.py @@ -2,8 +2,8 @@ import torch.nn as nn from typing_extensions import override from vllm.config import VllmConfig +from vllm.config.utils import replace from vllm.model_executor.model_loader import get_model -from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer @@ -44,15 +44,28 @@ def _raise_if_draft_tp_mismatch(self): "Please pass 'draft_tensor_parallel_size' in the speculative_config." ) + def _create_draft_vllm_config(self) -> VllmConfig: + base = super()._create_draft_vllm_config() + spec = self.speculative_config + + return replace( + base, + quant_config=None, + parallel_config=replace( + spec.draft_parallel_config, + rank=self.vllm_config.parallel_config.rank, + ), + model_config=spec.draft_model_config, + ) + + @override def _get_model(self) -> nn.Module: - # Draft models may be quantized or on different parallelism, - # so we load them with a modified vllm config from vllm.compilation.backends import set_model_tag - temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config) + draft_vllm_config = self._create_draft_vllm_config() with set_model_tag("draft_model"): model = get_model( - vllm_config=temp_vllm_config, + vllm_config=draft_vllm_config, prefix="draft_model", ) return model diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index fde96bce8d7..ba7852f17d4 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -606,6 +606,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.conv import AscendConv3dLayer from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE + from vllm_ascend.ops.gdn import AscendGatedDeltaNetAttention from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm, AscendRMSNormGated from vllm_ascend.ops.linear import ( AscendColumnParallelLinear, @@ -658,6 +659,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "Conv3dLayer": AscendConv3dLayer, "RelPosAttention": AscendRelPosAttention, "CustomQwen2Decoder": AscendCustomQwen2Decoder, + "GatedDeltaNetAttention": AscendGatedDeltaNetAttention, } # 310P: override selected ops with 310P implementations (keep minimal changes outside _310p) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 3c812aa4432..39dccff35f8 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -104,6 +104,12 @@ def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 self.append_row(block_ids, row_idx) + def clear_row(self, row_idx: int) -> None: + num_blocks = self.num_blocks_per_row[row_idx] + if num_blocks > 0: + self.block_table.np[row_idx, :num_blocks] = 0 + self.num_blocks_per_row[row_idx] = 0 + def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks] @@ -291,6 +297,10 @@ def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) + def clear_row(self, row_idx: int) -> None: + for block_table in self.block_tables: + block_table.clear_row(row_idx) + def move_row(self, src: int, tgt: int) -> None: for block_table in self.block_tables: block_table.move_row(src, tgt)