2 changes: 1 addition & 1 deletion .github/workflows/dockerfiles/Dockerfile.lint
@@ -27,7 +27,7 @@ RUN apt-get update -y && \
 
 ARG VLLM_REPO=https://github.com/vllm-project/vllm.git
 # For lint purpose, actually we need make a main2main matching.
-ARG VLLM_COMMIT=6f786f2c506cb07f4566771fdc62e640e2c4a176
+ARG VLLM_COMMIT=d886c26d4d4fef7d079696beb4ece1cfb4b008a8
 RUN git init /vllm-workspace/vllm && \
     git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \
     git -C /vllm-workspace/vllm checkout FETCH_HEAD
2 changes: 1 addition & 1 deletion .github/workflows/pr_test_full.yaml
@@ -80,7 +80,7 @@ jobs:
     name: e2e-full
     strategy:
       matrix:
-        vllm_version: [6f786f2c506cb07f4566771fdc62e640e2c4a176, v0.19.1]
+        vllm_version: [d886c26d4d4fef7d079696beb4ece1cfb4b008a8, v0.19.1]
     needs: [changes]
     if: ${{ needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.e2e_tracker == true }}
     uses: ./.github/workflows/_e2e_test.yaml
6 changes: 3 additions & 3 deletions .github/workflows/pr_test_light.yaml
@@ -41,7 +41,7 @@ jobs:
   lint:
     uses: ./.github/workflows/_pre_commit.yml
     with:
-      vllm: 6f786f2c506cb07f4566771fdc62e640e2c4a176
+      vllm: d886c26d4d4fef7d079696beb4ece1cfb4b008a8
   changes:
     runs-on: linux-aarch64-a2b3-0
     container:
@@ -154,7 +154,7 @@ jobs:
     if: ${{ needs.lint.result == 'success' && needs.changes.outputs.has_tests == 'true' }}
     strategy:
       matrix:
-        vllm_version: [6f786f2c506cb07f4566771fdc62e640e2c4a176, v0.19.1]
+        vllm_version: [d886c26d4d4fef7d079696beb4ece1cfb4b008a8, v0.19.1]
     uses: ./.github/workflows/_optional_smart_e2e.yaml
     with:
       vllm: ${{ matrix.vllm_version }}
@@ -164,7 +164,7 @@
     name: e2e-light
     strategy:
      matrix:
-        vllm_version: [6f786f2c506cb07f4566771fdc62e640e2c4a176, v0.19.1]
+        vllm_version: [d886c26d4d4fef7d079696beb4ece1cfb4b008a8, v0.19.1]
     # Note (yikun): If CI resource are limited we can split job into two chain jobs
     needs: [lint, changes]
     # only trigger e2e test after lint passed and the change is e2e related with pull request.
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -80,7 +80,7 @@
     # CANN image tag
     "cann_image_tag": "8.5.1-910b-ubuntu22.04-py3.11",
     # vLLM commit hash for main branch
-    "main_vllm_commit": "6f786f2c506cb07f4566771fdc62e640e2c4a176",
+    "main_vllm_commit": "d886c26d4d4fef7d079696beb4ece1cfb4b008a8",
     # vLLM tag for main branch
     "main_vllm_tag": "v0.19.1",
     # Python version for main branch
15 changes: 11 additions & 4 deletions tests/ut/ops/test_mla.py
@@ -8,6 +8,7 @@
 
 from tests.ut.base import TestBase
 from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention, IndexerWrapper
+from vllm_ascend.utils import vllm_version_is
 
 
 class TestIndexerWrapper(TestBase):
@@ -18,8 +19,11 @@ def test_initialization(self):
         mock_indexer.topk_tokens = 2048
         mock_indexer.q_lora_rank = 1536
         mock_indexer.wq_b = nn.Linear(128, 128)
-        mock_indexer.wk = nn.Linear(128, 128)
-        mock_indexer.weights_proj = nn.Linear(128, 128)
+        if vllm_version_is("0.19.1"):
+            mock_indexer.wk = nn.Linear(128, 128)
+            mock_indexer.weights_proj = nn.Linear(128, 128)
+        else:
+            mock_indexer.wk_weights_proj = nn.Linear(128, 128)
         mock_indexer.k_norm = nn.LayerNorm(128)
         mock_indexer.softmax_scale = 0.123
         mock_indexer.topk_indices_buffer = torch.randn(10)
@@ -32,8 +36,11 @@ def test_initialization(self):
         self.assertEqual(wrapper.topk_tokens, 2048)
         self.assertEqual(wrapper.q_lora_rank, 1536)
         self.assertIs(wrapper.wq_b, mock_indexer.wq_b)
-        self.assertIs(wrapper.wk, mock_indexer.wk)
-        self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
+        if vllm_version_is("0.19.1"):
+            self.assertIs(wrapper.wk, mock_indexer.wk)
+            self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
+        else:
+            self.assertIs(wrapper.wk_weights_proj, mock_indexer.wk_weights_proj)
         self.assertIs(wrapper.k_norm, mock_indexer.k_norm)
         self.assertEqual(wrapper.softmax_scale, 0.123)
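For readers unfamiliar with the gating helper used throughout this PR, here is a minimal illustrative sketch (not part of the change set) of what a version gate like vllm_ascend.utils.vllm_version_is can look like, assuming it simply compares the installed vllm.__version__ against a pinned version string; the real helper may differ.

    import vllm

    def vllm_version_is(target: str) -> bool:
        # Sketch only: True when the installed vLLM release matches the pinned version string.
        return vllm.__version__ == target

    # Call sites then branch between the v0.19.1 module layout (separate wk / weights_proj)
    # and the newer layout (fused wk_weights_proj), exactly as the hunks above do.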
7 changes: 6 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -20,6 +20,7 @@
     is_drafter_moe_model,
     is_moe_model,
     speculative_enable_dispatch_gmm_combine_decode,
+    vllm_version_is,
 )
 
 
@@ -154,7 +155,11 @@ def set_ascend_forward_context(
 
     dp_world_size = get_dp_group().world_size
     if dp_world_size > 1 and forward_context.dp_metadata is not None:
-        max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
+        dp_meta = forward_context.dp_metadata
+        if vllm_version_is("0.19.1"):
+            max_tokens_across_dp = dp_meta.max_tokens_across_dp_cpu.item()
+        else:
+            max_tokens_across_dp = dp_meta.num_tokens_across_dp_cpu.max().item()
        if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
             padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
             pad_size = padded_length - num_tokens
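The dp_metadata change above swaps a precomputed scalar (max_tokens_across_dp_cpu) for a per-rank CPU tensor (num_tokens_across_dp_cpu) whose maximum must now be taken explicitly. A small illustrative sketch of the resulting padding math, with hypothetical token counts (not part of the change set):

    import torch

    num_tokens_across_dp_cpu = torch.tensor([7, 12, 9, 10])  # hypothetical tokens per DP rank
    tp_world_size = 8
    num_tokens = 9  # tokens on the local rank

    max_tokens_across_dp = num_tokens_across_dp_cpu.max().item()  # 12
    # Round the DP-wide maximum up to the next multiple of the TP world size.
    padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size  # 16
    pad_size = padded_length - num_tokens  # 7 padding tokens for this rank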
7 changes: 6 additions & 1 deletion vllm_ascend/attention/context_parallel/sfa_cp.py
@@ -12,6 +12,7 @@
 from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills
 from vllm_ascend.ops.triton.rope import rope_forward_triton_siso
+from vllm_ascend.utils import vllm_version_is
 
 M = TypeVar("M", bound=AscendSFAMetadata)
 
@@ -413,7 +414,11 @@ def indexer_select_post_process(
         actual_seq_lengths_query: torch.Tensor,
         actual_seq_lengths_key: torch.Tensor,
     ):
-        weights, _ = self.weights_proj(x)
+        if vllm_version_is("0.19.1"):
+            weights, _ = self.weights_proj(x)
+        else:
+            kw, _ = self.wk_weights_proj(x)
+            weights = kw[:, self.head_dim :]
 
         q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
         q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
20 changes: 16 additions & 4 deletions vllm_ascend/attention/sfa_v1.py
@@ -55,6 +55,7 @@
     enable_dsa_cp_with_o_proj_tp,
     get_weight_prefetch_method,
     maybe_trans_nz,
+    vllm_version_is,
 )
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 
@@ -438,8 +439,11 @@ def __init__(
         self.n_head: int = self.indexer.n_head # 64
         self.head_dim: int = self.indexer.head_dim # 128
         self.wq_b = self.indexer.wq_b
-        self.wk = self.indexer.wk
-        self.weights_proj = self.indexer.weights_proj
+        if vllm_version_is("0.19.1"):
+            self.wk = self.indexer.wk
+            self.weights_proj = self.indexer.weights_proj
+        else:
+            self.wk_weights_proj = self.indexer.wk_weights_proj
         self.k_norm = self.indexer.k_norm
         self.cp_size = 1
         self.is_rope_neox_style = True
@@ -908,7 +912,11 @@ def indexer_select_pre_process(
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
+        if vllm_version_is("0.19.1"):
+            k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
+        else:
+            kw, _ = self.wk_weights_proj(x)
+            k_li = kw[:, : self.head_dim]
         k_li = self.k_norm(k_li).unsqueeze(1)
         k_li = k_li.view(-1, 1, self.head_dim)
 
@@ -953,7 +961,11 @@ def indexer_select_post_process(
         actual_seq_lengths_query: torch.Tensor,
         actual_seq_lengths_key: torch.Tensor,
     ):
-        weights, _ = self.weights_proj(x)
+        if vllm_version_is("0.19.1"):
+            weights, _ = self.weights_proj(x)
+        else:
+            kw, _ = self.wk_weights_proj(x)
+            weights = kw[:, self.head_dim :]
 
         q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
         q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
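Both new indexer paths above slice the output of the fused layer: the pre-process step keeps the first head_dim columns as the key projection, and the post-process step keeps the remainder as per-token weights. A minimal sketch of that split (not part of the change set), assuming the fused wk_weights_proj simply concatenates the old wk and weights_proj outputs along the last dimension and that the weights part is n_head wide; a plain nn.Linear is used here, whereas vLLM's parallel linear layers return an (output, bias) tuple.

    import torch
    import torch.nn as nn

    hidden, head_dim, n_head = 7168, 128, 64      # sizes taken from the comments in the diff
    wk_weights_proj = nn.Linear(hidden, head_dim + n_head, bias=False)

    x = torch.randn(4, hidden)                    # [n_toks, hidden]
    kw = wk_weights_proj(x)                       # [n_toks, head_dim + n_head]
    k_li = kw[:, :head_dim]                       # key projection, as in indexer_select_pre_process
    weights = kw[:, head_dim:]                    # per-token weights, as in indexer_select_post_process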
12 changes: 8 additions & 4 deletions vllm_ascend/ops/mla.py
@@ -33,15 +33,15 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
-from vllm_ascend.utils import is_vl_model, parse_layer_idx
+from vllm_ascend.utils import is_vl_model, parse_layer_idx, vllm_version_is
 
 
 class IndexerWrapper(nn.Module):
     """
     A wrapper of Indexer for Deepseek v3.2.
     This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py.
     It wraps the original Indexer, inherits its module weights
-    (including wq_b, wk, weights_proj, k_norm)
+    (including wq_b, wk_weights_proj or wk/weights_proj, k_norm)
     while deletes the unused topk_indices_buffer and k_cache to save memory.
     TODO: Will be removed once original Indexer supports different quantization methods.
     """
@@ -54,8 +54,12 @@ def __init__(self, vllm_indexer: nn.Module) -> None:
         self.topk_tokens: int = vllm_indexer.topk_tokens # 2048
         self.q_lora_rank: int = vllm_indexer.q_lora_rank # 1536
         self.wq_b = vllm_indexer.wq_b
-        self.wk = vllm_indexer.wk
-        self.weights_proj = vllm_indexer.weights_proj
+        # upstream ac3dac545 fused wk+weights_proj into wk_weights_proj
+        if vllm_version_is("0.19.1"):
+            self.wk = vllm_indexer.wk
+            self.weights_proj = vllm_indexer.weights_proj
+        else:
+            self.wk_weights_proj = vllm_indexer.wk_weights_proj
         self.k_norm = vllm_indexer.k_norm
         self.softmax_scale = vllm_indexer.softmax_scale
         vllm_indexer.topk_indices_buffer = None # delete topk_indices_buffer
6 changes: 5 additions & 1 deletion vllm_ascend/platform.py
@@ -48,6 +48,7 @@
     update_cudagraph_capture_sizes,
     is_310p,
     enable_sp,
+    vllm_version_is,
 )
 
 if TYPE_CHECKING:
@@ -756,7 +757,11 @@ def set_additional_forward_context(
         num_tokens = list(attn_metadata.values())[0].num_actual_tokens
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and dp_metadata is not None:
-            max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
+            if vllm_version_is("0.19.1"):
+                max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
+            else:
+                max_tokens_across_dp = dp_metadata.num_tokens_across_dp_cpu.max().item()
             if flash_comm_v1_enabled or flashcomm_v2_enabled:
                 padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                 pad_size = padded_length - num_tokens
14 changes: 12 additions & 2 deletions vllm_ascend/worker/worker.py
@@ -60,9 +60,13 @@
     enable_sp,
     get_ascend_device_type,
     register_ascend_customop,
+    vllm_version_is,
 )
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+if not vllm_version_is("0.19.1"):
+    from vllm.v1.worker.worker_base import CompilationTimes  # noqa: E402
+
 torch._dynamo.trace_rules.clear_lru_cache()  # noqa: E402
 from torch._dynamo.variables import TorchInGraphFunctionVariable  # noqa: E402
 from vllm.utils.torch_utils import set_random_seed  # noqa: E402
@@ -470,7 +474,7 @@ def load_model(self) -> None:
         with context, set_current_vllm_config(self.vllm_config):
             self.model_runner.load_model()
 
-    def compile_or_warm_up_model(self) -> float:
+    def compile_or_warm_up_model(self):
         # Note: need to adapt for graph mode.
         warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy()
         if not self.model_config.enforce_eager:
@@ -550,7 +554,13 @@ def compile_or_warm_up_model(self) -> float:
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
-        return self.vllm_config.compilation_config.compilation_time
+        if vllm_version_is("0.19.1"):
+            return self.vllm_config.compilation_config.compilation_time
+
+        return CompilationTimes(
+            language_model=self.vllm_config.compilation_config.compilation_time,
+            encoder=self.compilation_config.encoder_compilation_time,
+        )
 
     def _warm_up_atb(self):
         x = torch.rand((2, 4), dtype=torch.float16).npu()