2 changes: 1 addition & 1 deletion .github/workflows/_e2e_test.yaml
@@ -36,7 +36,7 @@ on:
continue_on_error:
required: false
type: boolean
default: false
default: true
# The following inputs are used by comment-triggered E2E tests (/e2e <tests>).
# They carry space-separated pytest paths, categorized by runner type.
# Leave empty (default) when running label-triggered full/light suites.
2 changes: 1 addition & 1 deletion .github/workflows/dockerfiles/Dockerfile.lint
@@ -27,7 +27,7 @@ RUN apt-get update -y && \

ARG VLLM_REPO=https://github.com/vllm-project/vllm.git
# For lint purpose, actually we need make a main2main matching.
ARG VLLM_COMMIT=6f786f2c506cb07f4566771fdc62e640e2c4a176
ARG VLLM_COMMIT=ccaf5ffaa3e1fb2a081b2c9e403ac0e4dfc142c8
RUN git init /vllm-workspace/vllm && \
git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \
git -C /vllm-workspace/vllm checkout FETCH_HEAD
2 changes: 1 addition & 1 deletion .github/workflows/pr_test_full.yaml
@@ -80,7 +80,7 @@ jobs:
name: e2e-full
strategy:
matrix:
vllm_version: [6f786f2c506cb07f4566771fdc62e640e2c4a176, v0.19.0]
vllm_version: [ccaf5ffaa3e1fb2a081b2c9e403ac0e4dfc142c8, v0.19.0]
needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.e2e_tracker == true }}
uses: ./.github/workflows/_e2e_test.yaml
6 changes: 3 additions & 3 deletions .github/workflows/pr_test_light.yaml
@@ -41,7 +41,7 @@ jobs:
lint:
uses: ./.github/workflows/_pre_commit.yml
with:
vllm: 6f786f2c506cb07f4566771fdc62e640e2c4a176
vllm: ccaf5ffaa3e1fb2a081b2c9e403ac0e4dfc142c8
changes:
runs-on: linux-aarch64-a2b3-0
outputs:
@@ -92,7 +92,7 @@ jobs:
if: ${{ needs.lint.result == 'success' && (needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.ut_tracker == 'true') }}
strategy:
matrix:
vllm_version: [6f786f2c506cb07f4566771fdc62e640e2c4a176, v0.19.0]
vllm_version: [ccaf5ffaa3e1fb2a081b2c9e403ac0e4dfc142c8, v0.19.0]
uses: ./.github/workflows/_unit_test.yaml
with:
vllm: ${{ matrix.vllm_version }}
@@ -104,7 +104,7 @@
name: e2e-light
strategy:
matrix:
vllm_version: [6f786f2c506cb07f4566771fdc62e640e2c4a176, v0.19.0]
vllm_version: [ccaf5ffaa3e1fb2a081b2c9e403ac0e4dfc142c8, v0.19.0]
# Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request.
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -80,7 +80,7 @@
# CANN image tag
"cann_image_tag": "8.5.1-910b-ubuntu22.04-py3.11",
# vLLM commit hash for main branch
"main_vllm_commit": "6f786f2c506cb07f4566771fdc62e640e2c4a176",
"main_vllm_commit": "ccaf5ffaa3e1fb2a081b2c9e403ac0e4dfc142c8",
# vLLM tag for main branch
"main_vllm_tag": "v0.19.0",
Comment on lines +83 to 85
Contributor
high

The Pull Request title and summary do not adhere to the repository style guide. Please update them according to the following suggestions:

Suggested PR Title:

[Ops][Test][Doc][Misc] Update vLLM to v0.19.0 and handle fused MLA weights

Suggested PR Summary:

### What this PR does / why we need it?
This PR updates the pinned vLLM main-branch commit hash (newer than v0.19.0) and introduces compatibility changes to handle breaking changes in upstream vLLM, including:
- Fused `wk` and `weights_proj` in MLA/SFA.
- Changes in `dp_metadata` fields.
- Changes in `CompilationTimes` structure.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with updated unit tests supporting the new vLLM version.
References
  1. PR Title and Summary must follow a specific format and be provided in markdown code blocks. (link)

# Python version for main branch
2 changes: 1 addition & 1 deletion requirements.txt
@@ -35,6 +35,6 @@ numba
torch-npu==2.9.0

arctic-inference==0.1.1
transformers>=4.57.4
transformers>=4.57.4, <5.0
fastapi<0.124.0
triton-ascend==3.2.0
15 changes: 11 additions & 4 deletions tests/ut/ops/test_mla.py
@@ -8,6 +8,7 @@

from tests.ut.base import TestBase
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention, IndexerWrapper
from vllm_ascend.utils import vllm_version_is


class TestIndexerWrapper(TestBase):
@@ -19,8 +20,11 @@ def test_initialization(self):
mock_indexer.topk_tokens = 2048
mock_indexer.q_lora_rank = 1536
mock_indexer.wq_b = nn.Linear(128, 128)
mock_indexer.wk = nn.Linear(128, 128)
mock_indexer.weights_proj = nn.Linear(128, 128)
if vllm_version_is("0.19.0"):
mock_indexer.wk = nn.Linear(128, 128)
mock_indexer.weights_proj = nn.Linear(128, 128)
else:
mock_indexer.wk_weights_proj = nn.Linear(128, 128)
mock_indexer.k_norm = nn.LayerNorm(128)
mock_indexer.softmax_scale = 0.123
mock_indexer.topk_indices_buffer = torch.randn(10)
@@ -33,8 +37,11 @@ def test_initialization(self):
self.assertEqual(wrapper.topk_tokens, 2048)
self.assertEqual(wrapper.q_lora_rank, 1536)
self.assertIs(wrapper.wq_b, mock_indexer.wq_b)
self.assertIs(wrapper.wk, mock_indexer.wk)
self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
if vllm_version_is("0.19.0"):
self.assertIs(wrapper.wk, mock_indexer.wk)
self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
else:
self.assertIs(wrapper.wk_weights_proj, mock_indexer.wk_weights_proj)
self.assertIs(wrapper.k_norm, mock_indexer.k_norm)
self.assertEqual(wrapper.softmax_scale, 0.123)

86 changes: 61 additions & 25 deletions vllm_ascend/_310p/fused_moe/fused_moe.py
@@ -20,7 +20,6 @@
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE

from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
@@ -29,6 +28,9 @@
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.19.0"):
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE # type: ignore[no-redef]

from .experts_selector import select_experts
from .moe_comm_method import AllGatherCommImpl310

@@ -118,7 +120,12 @@ def apply(

class AscendFusedMoE310(FusedMoE):
def __init__(self, *args, **kwargs):
is_legacy = vllm_version_is("0.19.0")
if not is_legacy:
_routed_input_transform = kwargs.get("routed_input_transform")
super().__init__(*args, **kwargs)
if not is_legacy:
self.reduce_results = False

self.global_num_experts = kwargs["num_experts"]

@@ -164,18 +171,29 @@ def __init__(self, *args, **kwargs):

from vllm_ascend.ops.fused_moe.fused_moe import AscendMoERunner

is_legacy = vllm_version_is("0.19.0")
self.runner = AscendMoERunner(
self if is_legacy else self.layer_name,
self.moe_config,
self.router,
self._routed_input_transform,
self.gate if is_legacy else kwargs.pop("gate", None),
self.shared_experts if is_legacy else kwargs.pop("shared_experts", None),
self.quant_method,
self.reduce_results,
self.vllm_config.parallel_config.enable_dbo,
)
if is_legacy:
self.runner = AscendMoERunner(
self,
self.moe_config,
self.router,
self._routed_input_transform,
self.gate,
self.shared_experts,
self.quant_method,
self.reduce_results,
self.vllm_config.parallel_config.enable_dbo,
)
else:
self.runner = AscendMoERunner(
self.layer_name,
self.moe_config,
self.router,
_routed_input_transform,
kwargs.get("gate"),
kwargs.get("shared_experts"),
self.quant_method,
self.vllm_config.parallel_config.enable_dbo,
)

def init_experts_map(self, moe_config):
"""
@@ -221,6 +239,9 @@ def get_quant_type(self) -> QuantType:
raise RuntimeError("Only Unquant and W8A8 is supported.")
return quant_type

def maybe_init_modular_kernel(self) -> None:
return None

def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
) -> torch.Tensor:
@@ -263,7 +284,10 @@ def forward_impl(  # type: ignore[override]
return routed_out


class AscendSharedFusedMoE310(SharedFusedMoE, AscendFusedMoE310):
_SharedFusedMoEBase310 = (SharedFusedMoE, AscendFusedMoE310) if vllm_version_is("0.19.0") else (AscendFusedMoE310,)


class AscendSharedFusedMoE310(*_SharedFusedMoEBase310): # type: ignore[misc]
def __init__(
self,
shared_experts: torch.nn.Module,
@@ -286,17 +310,29 @@ def __init__(
from vllm_ascend.ops.fused_moe.fused_moe import AscendMoERunner

is_legacy = vllm_version_is("0.19.0")
self.runner = AscendMoERunner(
self if is_legacy else self.layer_name,
self.moe_config,
self.router,
self._routed_input_transform,
self._gate,
self._shared_experts,
self.quant_method,
self.reduce_results,
self.vllm_config.parallel_config.enable_dbo,
)
if is_legacy:
self.runner = AscendMoERunner(
self,
self.moe_config,
self.router,
self._routed_input_transform,
self._gate,
self._shared_experts,
self.quant_method,
self.reduce_results,
self.vllm_config.parallel_config.enable_dbo,
)
else:
self.runner = AscendMoERunner(
self.layer_name,
self.moe_config,
self.router,
self._routed_input_transform,
self._gate,
self._shared_experts,
self.quant_method,
self.vllm_config.parallel_config.enable_dbo,
)

@property
def is_internal_router(self) -> bool:
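Note on the conditional base tuple `_SharedFusedMoEBase310` above: unpacking a version-dependent tuple into the class bases keeps `SharedFusedMoE` in the MRO of `AscendSharedFusedMoE310` only when it is importable (vLLM v0.19.0). A toy, self-contained sketch of the pattern, using stand-in classes and a plain boolean in place of the `vllm_version_is("0.19.0")` check:

```python
# Stand-ins for SharedFusedMoE (only importable on vLLM v0.19.0) and
# AscendFusedMoE310; the real classes carry the actual MoE behaviour.
class LegacyShared:
    pass


class AscendMoE:
    pass


legacy_available = True  # illustrative stand-in for vllm_version_is("0.19.0")
_Base = (LegacyShared, AscendMoE) if legacy_available else (AscendMoE,)


class SharedMoE(*_Base):  # bases are unpacked from the version-dependent tuple
    pass


# With legacy_available=True: SharedMoE -> LegacyShared -> AscendMoE -> object
print([cls.__name__ for cls in SharedMoE.__mro__])
```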
7 changes: 6 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -20,6 +20,7 @@
is_drafter_moe_model,
is_moe_model,
speculative_enable_dispatch_gmm_combine_decode,
vllm_version_is,
)


@@ -153,7 +154,11 @@ def set_ascend_forward_context(

dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
dp_meta = forward_context.dp_metadata
if vllm_version_is("0.19.0"):
max_tokens_across_dp = dp_meta.max_tokens_across_dp_cpu.item()
else:
max_tokens_across_dp = dp_meta.num_tokens_across_dp_cpu.max().item()
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
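For context on the branch above: on v0.19.0 `dp_metadata` exposes a one-element `max_tokens_across_dp_cpu`, while newer main exposes per-rank `num_tokens_across_dp_cpu`, so the maximum is taken explicitly before the same round-up-to-TP padding. A small self-contained sketch with illustrative values:

```python
import torch

# Illustrative per-rank token counts; the real values come from
# forward_context.dp_metadata on newer vLLM main.
num_tokens_across_dp_cpu = torch.tensor([97, 100, 64, 80])
max_tokens_across_dp = num_tokens_across_dp_cpu.max().item()  # 100

# Same ceil-to-multiple-of-TP padding as in the hunk above.
tp_world_size = 8
num_tokens = 97
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
print(padded_length, pad_size)  # 104 7
```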
7 changes: 6 additions & 1 deletion vllm_ascend/attention/context_parallel/sfa_cp.py
@@ -12,6 +12,7 @@
from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills
from vllm_ascend.ops.triton.rope import rope_forward_triton_siso
from vllm_ascend.utils import vllm_version_is

M = TypeVar("M", bound=AscendSFAMetadata)

@@ -385,7 +386,11 @@ def indexer_select_post_process(
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
):
weights, _ = self.weights_proj(x)
if vllm_version_is("0.19.0"):
weights, _ = self.weights_proj(x)
else:
kw, _ = self.wk_weights_proj(x)
weights = kw[:, self.head_dim :]
Comment on lines +389 to +393
Contributor
high

The logic for fused weights appears to be inverted. If v0.19.0 has fused weights, it should use wk_weights_proj.

Suggested change
if vllm_version_is("0.19.0"):
weights, _ = self.weights_proj(x)
else:
kw, _ = self.wk_weights_proj(x)
weights = kw[:, self.head_dim :]
if vllm_version_is("0.19.0"):
kw, _ = self.wk_weights_proj(x)
weights = kw[:, self.head_dim :]
else:
weights, _ = self.weights_proj(x)
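A small, self-contained sketch of the equivalence the slicing above relies on: if the fused layer's weight stacks the `wk` rows first and the `weights_proj` rows after them, slicing the fused output at `head_dim` reproduces the two separate projections. Plain `nn.Linear` modules and illustrative sizes are used here; the real vLLM linear layers additionally return an `(output, bias)` tuple:

```python
import torch
import torch.nn as nn

hidden, head_dim, n_head = 16, 4, 3  # illustrative sizes only
torch.manual_seed(0)

# Separate projections, as on vLLM v0.19.0.
wk = nn.Linear(hidden, head_dim, bias=False)
weights_proj = nn.Linear(hidden, n_head, bias=False)

# Fused projection, as on newer vLLM main: wk rows first, weights_proj rows after.
wk_weights_proj = nn.Linear(hidden, head_dim + n_head, bias=False)
with torch.no_grad():
    wk_weights_proj.weight.copy_(torch.cat([wk.weight, weights_proj.weight], dim=0))

x = torch.randn(5, hidden)
kw = wk_weights_proj(x)
k_li = kw[:, :head_dim]     # matches wk(x)
weights = kw[:, head_dim:]  # matches weights_proj(x)

assert torch.allclose(k_li, wk(x), atol=1e-6)
assert torch.allclose(weights, weights_proj(x), atol=1e-6)
```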


q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
21 changes: 17 additions & 4 deletions vllm_ascend/attention/sfa_v1.py
@@ -55,6 +55,7 @@
enable_dsa_cp_with_o_proj_tp,
get_weight_prefetch_method,
maybe_trans_nz,
vllm_version_is,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

@@ -438,8 +439,12 @@ def __init__(
self.n_head: int = self.indexer.n_head # 64
self.head_dim: int = self.indexer.head_dim # 128
self.wq_b = self.indexer.wq_b
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
# upstream ac3dac545 fused wk+weights_proj into wk_weights_proj
if vllm_version_is("0.19.0"):
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
else:
self.wk_weights_proj = self.indexer.wk_weights_proj
Comment on lines +443 to +447
Contributor
high

The logic for fused weights appears to be inverted here, which is inconsistent with the logic used in ascend_forward_context.py. If v0.19.0 is the version that introduced fused weights, it should use wk_weights_proj.

Suggested change
if vllm_version_is("0.19.0"):
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
else:
self.wk_weights_proj = self.indexer.wk_weights_proj
if vllm_version_is("0.19.0"):
self.wk_weights_proj = self.indexer.wk_weights_proj
else:
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj

self.k_norm = self.indexer.k_norm
self.cp_size = 1
self.is_rope_neox_style = True
@@ -908,7 +913,11 @@ def indexer_select_pre_process(
cos: torch.Tensor,
sin: torch.Tensor,
):
k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
if vllm_version_is("0.19.0"):
k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
else:
kw, _ = self.wk_weights_proj(x)
k_li = kw[:, : self.head_dim]
Comment on lines +916 to +920
Contributor
high

The logic for fused weights appears to be inverted. If v0.19.0 has fused weights, it should use wk_weights_proj.

Suggested change
if vllm_version_is("0.19.0"):
k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
else:
kw, _ = self.wk_weights_proj(x)
k_li = kw[:, : self.head_dim]
if vllm_version_is("0.19.0"):
kw, _ = self.wk_weights_proj(x)
k_li = kw[:, : self.head_dim]
else:
k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]

k_li = self.k_norm(k_li).unsqueeze(1)
k_li = k_li.view(-1, 1, self.head_dim)

@@ -953,7 +962,11 @@ def indexer_select_post_process(
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
):
weights, _ = self.weights_proj(x)
if vllm_version_is("0.19.0"):
weights, _ = self.weights_proj(x)
else:
kw, _ = self.wk_weights_proj(x)
weights = kw[:, self.head_dim :]

q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
36 changes: 24 additions & 12 deletions vllm_ascend/lora/utils.py
@@ -26,6 +26,7 @@
AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
from vllm_ascend.utils import vllm_version_is


class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -184,16 +185,27 @@ def can_replace_layer(
return type(source_layer) is AscendRowParallelLinear


_ASCEND_LORA_CLASSES = (
AscendColumnParallelLinearWithLoRA,
AscendMergedColumnParallelLinearWithLoRA,
AscendRowParallelLinearWithLoRA,
AscendVocabParallelEmbeddingWithLoRA,
AscendQKVParallelLinearWithLoRA,
AscendMergedQKVParallelLinearWithLoRA,
AscendColumnParallelLinearWithShardedLoRA,
AscendMergedColumnParallelLinearWithShardedLoRA,
AscendMergedQKVParallelLinearWithShardedLoRA,
AscendQKVParallelLinearWithShardedLoRA,
AscendRowParallelLinearWithShardedLoRA,
AscendReplicatedLinearWithLoRA,
)


def refresh_all_lora_classes():
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendReplicatedLinearWithLoRA)
if vllm_version_is("0.19.0"):
vllm.lora.utils._all_lora_classes.update(_ASCEND_LORA_CLASSES)
return

vllm.lora.utils._all_lora_classes = tuple(
dict.fromkeys((*_ASCEND_LORA_CLASSES, *vllm.lora.utils._all_lora_classes))
)
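Note on the non-legacy branch above: `_all_lora_classes` is rebuilt as a tuple with the Ascend classes placed first, and `dict.fromkeys` serves as an order-preserving de-duplication (first occurrence wins, insertion order kept, Python 3.7+). A toy illustration of the idiom, with strings standing in for the LoRA layer classes:

```python
ascend_classes = ("AscendColumnParallelLinearWithLoRA", "AscendRowParallelLinearWithLoRA")
upstream_classes = ("AscendRowParallelLinearWithLoRA", "UpstreamVocabEmbeddingWithLoRA")

# First occurrence wins, so the Ascend replacements come first and any
# duplicates already registered upstream are dropped.
merged = tuple(dict.fromkeys((*ascend_classes, *upstream_classes)))
assert merged == (
    "AscendColumnParallelLinearWithLoRA",
    "AscendRowParallelLinearWithLoRA",
    "UpstreamVocabEmbeddingWithLoRA",
)
```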