diff --git a/.github/workflows/dockerfiles/Dockerfile.lint b/.github/workflows/dockerfiles/Dockerfile.lint index 3771a9c246f..1cbc47ba549 100644 --- a/.github/workflows/dockerfiles/Dockerfile.lint +++ b/.github/workflows/dockerfiles/Dockerfile.lint @@ -27,10 +27,8 @@ RUN apt-get update -y && \ ARG VLLM_REPO=https://github.com/vllm-project/vllm.git # For lint purpose, actually we need make a main2main matching. -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # # Install vLLM common dependencies RUN python3 -m pip install -r /vllm-workspace/vllm/requirements/common.txt --extra-index https://download.pytorch.org/whl/cpu/ && \ diff --git a/.github/workflows/pr_test_full.yaml b/.github/workflows/pr_test_full.yaml index d496439de29..04bcdcfaf47 100644 --- a/.github/workflows/pr_test_full.yaml +++ b/.github/workflows/pr_test_full.yaml @@ -80,7 +80,7 @@ jobs: name: e2e-full strategy: matrix: - vllm_version: [4d51588e2381018348f1022dfa3a7698899805b7] + vllm_version: [c7aa186d67b6f051680831418e957c67f34ba7a2, v0.20.1] needs: [changes] if: ${{ needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.e2e_tracker == true }} uses: ./.github/workflows/_e2e_test.yaml @@ -102,7 +102,7 @@ jobs: strategy: fail-fast: false matrix: - vllm_version: [4d51588e2381018348f1022dfa3a7698899805b7] + vllm_version: [v0.20.1] needs: [parse-trigger] if: ${{ needs.parse-trigger.outputs.allowed == 'true' }} uses: ./.github/workflows/_e2e_test.yaml diff --git a/.github/workflows/pr_test_light.yaml b/.github/workflows/pr_test_light.yaml index cc79b6e631f..587ed100457 100644 --- a/.github/workflows/pr_test_light.yaml +++ b/.github/workflows/pr_test_light.yaml @@ -41,7 +41,7 @@ jobs: lint: uses: ./.github/workflows/_pre_commit.yml with: - vllm: 4d51588e2381018348f1022dfa3a7698899805b7 + vllm: c7aa186d67b6f051680831418e957c67f34ba7a2 changes: runs-on: linux-aarch64-a2b3-0 container: @@ -154,7 +154,7 @@ jobs: if: ${{ needs.lint.result == 'success' && needs.changes.outputs.has_tests == 'true' }} strategy: matrix: - vllm_version: [4d51588e2381018348f1022dfa3a7698899805b7] + vllm_version: [c7aa186d67b6f051680831418e957c67f34ba7a2, v0.20.1] uses: ./.github/workflows/_optional_smart_e2e.yaml with: vllm: ${{ matrix.vllm_version }} @@ -164,7 +164,7 @@ jobs: name: e2e-light strategy: matrix: - vllm_version: [4d51588e2381018348f1022dfa3a7698899805b7] + vllm_version: [c7aa186d67b6f051680831418e957c67f34ba7a2, v0.20.1] # Note (yikun): If CI resource are limited we can split job into two chain jobs needs: [lint, changes] # only trigger e2e test after lint passed and the change is e2e related with pull request. 
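The CI matrices above now exercise both a pinned main-branch commit and the released v0.20.1 tag, and the runtime code later in this patch branches on the installed vLLM version via vllm_ascend.utils.vllm_version_is. As a rough illustration only (a hypothetical sketch, not the actual helper in vllm_ascend/utils.py, which may resolve the version differently), such a gate could look like this:

# Hypothetical sketch of a version gate in the spirit of
# vllm_ascend.utils.vllm_version_is; not the actual implementation.
from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target: str) -> bool:
    """Return True when the installed vLLM release matches `target` exactly."""
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        # vLLM is not installed (e.g. a docs-only environment): never match.
        return False

Code paths then keep the v0.20.1 behavior when that release is installed and use the newer upstream APIs otherwise, which is the pattern applied throughout the files below.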
diff --git a/.github/workflows/schedule_update_estimated_time.yaml b/.github/workflows/schedule_update_estimated_time.yaml index a757fc54a55..904472e03bb 100644 --- a/.github/workflows/schedule_update_estimated_time.yaml +++ b/.github/workflows/schedule_update_estimated_time.yaml @@ -23,7 +23,7 @@ jobs: name: e2e-test strategy: matrix: - vllm_version: [4d51588e2381018348f1022dfa3a7698899805b7] + vllm_version: [v0.20.1] type: [full, light] uses: ./.github/workflows/_e2e_test.yaml with: diff --git a/.github/workflows/schedule_vllm_e2e_test.yaml b/.github/workflows/schedule_vllm_e2e_test.yaml index a0fede2a0d1..1ba8468c990 100644 --- a/.github/workflows/schedule_vllm_e2e_test.yaml +++ b/.github/workflows/schedule_vllm_e2e_test.yaml @@ -45,7 +45,7 @@ jobs: fail-fast: false matrix: part: [0, 1, 2, 3] - vllm: [4d51588e2381018348f1022dfa3a7698899805b7] + vllm: [v0.20.1] container: image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.1-910b-ubuntu22.04-py3.11 env: diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index 7a1138a8f67..62f43ef0473 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -166,4 +166,4 @@ e2e-multicard-4-cards: - name: tests/e2e/multicard/4-cards/test_pipeline_parallel.py estimated_time: 679 - name: tests/e2e/multicard/4-cards/test_profiling_chunk_performance.py - estimated_time: 1300 + estimated_time: 1300 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c7e03307a0e..f80c9ea822b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,12 +48,8 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -# ARG VLLM_TAG=v0.19.1 -# RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /vllm-workspace/vllm/[audio] --extra-index https://download.pytorch.org/whl/cpu/ && \ python3 -m pip uninstall -y triton && \ diff --git a/Dockerfile.310p b/Dockerfile.310p index 64e16180dab..15fd78ade6b 100644 --- a/Dockerfile.310p +++ b/Dockerfile.310p @@ -33,12 +33,8 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -# ARG VLLM_TAG=v0.19.1 -# RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /vllm-workspace/vllm/[audio] --extra-index https://download.pytorch.org/whl/cpu/ && \ python3 -m pip uninstall -y triton && \ diff --git a/Dockerfile.310p.openEuler b/Dockerfile.310p.openEuler index b0e141744a1..58f777f4b2d 100644 --- a/Dockerfile.310p.openEuler +++ b/Dockerfile.310p.openEuler @@ -32,12 +32,8 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -# ARG VLLM_TAG=v0.19.1 -# RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /vllm-workspace/vllm/[audio] --extra-index https://download.pytorch.org/whl/cpu/ && \ python3 -m pip uninstall -y triton && \ diff --git a/Dockerfile.a3 b/Dockerfile.a3 index 736bb9d8b2b..bd30bdac78b 100644 --- a/Dockerfile.a3 +++ b/Dockerfile.a3 @@ -50,12 +50,8 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -# ARG VLLM_TAG=v0.19.1 -# RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /vllm-workspace/vllm/[audio] --extra-index https://download.pytorch.org/whl/cpu/ && \ python3 -m pip uninstall -y triton && \ diff --git a/Dockerfile.a3.openEuler b/Dockerfile.a3.openEuler index 1720e44b618..d5969e34583 100644 --- a/Dockerfile.a3.openEuler +++ b/Dockerfile.a3.openEuler @@ -49,12 +49,8 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -# ARG VLLM_TAG=v0.19.1 -# RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /vllm-workspace/vllm/[audio] --extra-index https://download.pytorch.org/whl/cpu/ && \ python3 -m pip uninstall -y triton && \ diff --git a/Dockerfile.openEuler b/Dockerfile.openEuler index 8d9c50ca1ce..3722577c515 100644 --- a/Dockerfile.openEuler +++ b/Dockerfile.openEuler @@ -49,12 +49,8 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} && \ # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -# ARG VLLM_TAG=v0.19.1 -# RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm -ARG VLLM_COMMIT=4d51588e2381018348f1022dfa3a7698899805b7 -RUN git init /vllm-workspace/vllm && \ - git -C /vllm-workspace/vllm fetch --depth 1 $VLLM_REPO $VLLM_COMMIT && \ - git -C /vllm-workspace/vllm checkout FETCH_HEAD +ARG VLLM_TAG=v0.20.1 +RUN git clone --depth 1 -b $VLLM_TAG $VLLM_REPO /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /vllm-workspace/vllm/[audio] --extra-index https://download.pytorch.org/whl/cpu/ && \ python3 -m pip uninstall -y triton && \ diff --git a/docs/source/conf.py b/docs/source/conf.py index 57ab3184375..6bdd0ddaf0b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -81,9 +81,9 @@ # CANN image tag "cann_image_tag": "8.5.1-910b-ubuntu22.04-py3.11", # vLLM commit hash for main branch - "main_vllm_commit": "4d51588e2381018348f1022dfa3a7698899805b7", + "main_vllm_commit": "c7aa186d67b6f051680831418e957c67f34ba7a2", # vLLM tag for main branch - "main_vllm_tag": "v0.19.1", + "main_vllm_tag": "v0.20.1", # Python version for main branch "main_python_version": ">= 3.10, < 3.12", # CANN version for main branch diff --git a/mypy.ini b/mypy.ini index a289ba966c4..de6d10ac1a1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -44,3 +44,6 @@ ignore_missing_imports = True [mypy-jiwer] ignore_missing_imports = True +[mypy-vllm.v1.kv_offload.*] +ignore_missing_imports = True + diff --git a/tests/e2e/multicard/2-cards/test_qwen3_moe.py b/tests/e2e/multicard/2-cards/test_qwen3_moe.py index e6c3376e0e6..decca463361 100644 --- a/tests/e2e/multicard/2-cards/test_qwen3_moe.py +++ b/tests/e2e/multicard/2-cards/test_qwen3_moe.py @@ -25,6 +25,7 @@ from vllm.utils.network_utils import get_open_port from tests.e2e.conftest import RemoteOpenAIServer, VllmRunner +from vllm_ascend.utils import vllm_version_is @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) @@ -74,6 +75,7 @@ def test_qwen3_moe_distributed_aiv_tp2(): vllm_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.skipif(vllm_version_is("0.20.1"), reason="no need to support model_runner for v0.20.1") @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [True]) @patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"}) diff --git a/tests/e2e/singlecard/model_runner_v2/test_basic.py b/tests/e2e/singlecard/model_runner_v2/test_basic.py index 3edf1b4efc7..c12730c1c11 100644 --- a/tests/e2e/singlecard/model_runner_v2/test_basic.py +++ b/tests/e2e/singlecard/model_runner_v2/test_basic.py @@ -22,6 +22,7 @@ from vllm import SamplingParams from tests.e2e.conftest import VllmRunner +from vllm_ascend.utils import vllm_version_is MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"] @@ -29,6 +30,7 @@ EGALE_MODELS = ["vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"] +@pytest.mark.skipif(vllm_version_is("0.20.1"), reason="no need to support model_runner for v0.20.1") 
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("enforce_eager", [True]) @@ -63,6 +65,7 @@ def test_qwen3_dense_eager_mode( runner.model.generate(prompts, sampling_params) +@pytest.mark.skipif(vllm_version_is("0.20.1"), reason="no need to support model_runner for v0.20.1") @pytest.mark.parametrize("model", MAIN_MODELS) @pytest.mark.parametrize("eagle_model", EGALE_MODELS) @pytest.mark.parametrize("max_tokens", [32]) @@ -101,6 +104,7 @@ def test_egale_spec_decoding( runner.model.generate(prompts, sampling_params) +@pytest.mark.skipif(vllm_version_is("0.20.1"), reason="no need to support model_runner for v0.20.1") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("enforce_eager", [False]) diff --git a/tests/ut/worker/test_model_runner_v2.py b/tests/ut/worker/test_model_runner_v2.py deleted file mode 100644 index 508098deb2b..00000000000 --- a/tests/ut/worker/test_model_runner_v2.py +++ /dev/null @@ -1,59 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from vllm.v1.worker.gpu.model_runner import GPUModelRunner - -from vllm_ascend.ascend_forward_context import MoECommType, get_mrv2_in_profile_run -from vllm_ascend.worker.v2.model_runner import NPUModelRunner - - -class TestNPUModelRunnerV2(unittest.TestCase): - @staticmethod - def _make_runner(max_num_tokens: int = 16): - runner = NPUModelRunner.__new__(NPUModelRunner) - runner.max_num_tokens = max_num_tokens - runner.vllm_config = MagicMock() - return runner - - def test_profile_run_marks_only_mc2_warmup_dummy_run(self): - runner = self._make_runner(max_num_tokens=16) - observed_runs: list[tuple[int, bool]] = [] - - def fake_base_dummy_run(self, num_tokens, *args, **kwargs): - observed_runs.append((num_tokens, get_mrv2_in_profile_run())) - return None, None - - def fake_base_profile_run(self): - self._dummy_run(self.max_num_tokens, skip_attn=True) - - with ( - patch.object(GPUModelRunner, "_dummy_run", new=fake_base_dummy_run), - patch.object(GPUModelRunner, "profile_run", new=fake_base_profile_run), - patch("vllm_ascend.worker.v2.model_runner.get_mc2_tokens_capacity", return_value=8), - patch("vllm_ascend.worker.v2.model_runner.select_moe_comm_method", return_value=MoECommType.MC2), - ): - runner.profile_run() - - self.assertEqual(observed_runs, [(8, True), (16, True)]) - self.assertFalse(get_mrv2_in_profile_run()) - - def test_profile_run_keeps_normal_dummy_run_outside_profile_override(self): - runner = self._make_runner(max_num_tokens=16) - observed_runs: list[tuple[int, bool]] = [] - - def fake_base_dummy_run(self, num_tokens, *args, **kwargs): - observed_runs.append((num_tokens, get_mrv2_in_profile_run())) - return None, None - - def fake_base_profile_run(self): - self._dummy_run(self.max_num_tokens, skip_attn=True) - - with ( - patch.object(GPUModelRunner, "_dummy_run", new=fake_base_dummy_run), - patch.object(GPUModelRunner, "profile_run", new=fake_base_profile_run), - patch("vllm_ascend.worker.v2.model_runner.get_mc2_tokens_capacity", return_value=32), - patch("vllm_ascend.worker.v2.model_runner.select_moe_comm_method", return_value=MoECommType.MC2), - ): - runner.profile_run() - - self.assertEqual(observed_runs, [(16, True)]) diff --git a/vllm_ascend/core/scheduler_profiling_chunk.py b/vllm_ascend/core/scheduler_profiling_chunk.py index ef8285a77e0..b79aae83ff0 100644 --- a/vllm_ascend/core/scheduler_profiling_chunk.py +++ b/vllm_ascend/core/scheduler_profiling_chunk.py @@ -41,6 +41,7 
@@ from vllm.v1.utils import record_function_or_nullcontext from vllm_ascend.core.profiling_chunk_predictor import ProfilingChunkManager +from vllm_ascend.utils import vllm_version_is class ProfilingChunkScheduler(Scheduler): @@ -575,12 +576,16 @@ def schedule(self) -> SchedulerOutput: # noqa: C901 if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule: num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule) - if self.scheduler_reserve_full_isl and not self.kv_cache_manager.can_fit_full_sequence( - request, - num_new_computed_tokens=num_new_local_computed_tokens, - new_computed_blocks=new_computed_blocks, - num_external_computed_tokens=num_external_computed_tokens, - num_encoder_tokens=num_encoder_tokens, + if ( + vllm_version_is("0.20.1") + and self.scheduler_reserve_full_isl + and not self.kv_cache_manager.can_fit_full_sequence( + request, + num_new_computed_tokens=num_new_local_computed_tokens, + new_computed_blocks=new_computed_blocks, + num_external_computed_tokens=num_external_computed_tokens, + num_encoder_tokens=num_encoder_tokens, + ) ): if request.has_encoder_inputs: self.encoder_cache_manager.free(request) @@ -595,6 +600,9 @@ def schedule(self) -> SchedulerOutput: # noqa: C901 num_external_computed_tokens=num_external_computed_tokens, delay_cache_blocks=load_kv_async, num_encoder_tokens=num_encoder_tokens, + **( + {} if vllm_version_is("0.20.1") else {"full_sequence_must_fit": self.scheduler_reserve_full_isl} + ), ) if new_blocks is None: diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 329aadf2d95..2c92fb86631 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -19,6 +19,7 @@ import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa import vllm_ascend.patch.platform.patch_kv_cache_utils # noqa +import vllm_ascend.patch.platform.patch_mla_prefill_backend # noqa from vllm_ascend import envs from vllm_ascend.utils import is_310p diff --git a/vllm_ascend/patch/platform/patch_mla_prefill_backend.py b/vllm_ascend/patch/platform/patch_mla_prefill_backend.py new file mode 100644 index 00000000000..c904575cb3f --- /dev/null +++ b/vllm_ascend/patch/platform/patch_mla_prefill_backend.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# PR vllm-project/vllm#32623 introduced a new MLAPrefillBackend abstraction. +# When MLAAttention.__init__ calls get_mla_prefill_backend(), the upstream +# selector sees that Ascend NPU returns None for get_device_capability() and +# falls back to FlashAttnPrefillBackend, which asserts flash_attn_varlen_func +# is available — crashing on Ascend. +# +# Ascend's AscendSFAImpl/AscendMLAImpl handles the full forward pass (including +# prefill) via impl.forward(), so prefill_backend.run_prefill_* is never called. +# We register a no-op AscendMLAPrefillBackend and patch get_mla_prefill_backend +# so that MLAAttention.__init__ completes without error. 
+ +import torch +import vllm.model_executor.layers.attention.mla_attention + +from vllm_ascend.utils import vllm_version_is + +if not vllm_version_is("0.20.1"): + from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend + + class AscendMLAPrefillBackend(MLAPrefillBackend): + @staticmethod + def get_name() -> str: + return "ASCEND" + + @classmethod + def is_available(cls) -> bool: + return True + + def run_prefill_new_tokens( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + return_softmax_lse: bool, + ) -> torch.Tensor: + raise NotImplementedError("Ascend MLA prefill is handled by AscendSFAImpl/AscendMLAImpl") + + def run_prefill_context_chunk( + self, + chunk_idx: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError("Ascend MLA prefill is handled by AscendSFAImpl/AscendMLAImpl") + + vllm.model_executor.layers.attention.mla_attention.get_mla_prefill_backend = ( + lambda vllm_config: AscendMLAPrefillBackend + ) diff --git a/vllm_ascend/patch/worker/patch_v2/patch_triton.py b/vllm_ascend/patch/worker/patch_v2/patch_triton.py index 0352462673c..6dcdb447804 100644 --- a/vllm_ascend/patch/worker/patch_v2/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_v2/patch_triton.py @@ -1,8 +1,9 @@ from vllm.v1.worker.gpu import input_batch, model_runner from vllm.v1.worker.gpu.sample import bad_words, gumbel, logprob, penalties, prompt_logprob, sampler, states -from vllm.v1.worker.gpu.spec_decode import rejection_sampler +from vllm.v1.worker.gpu.spec_decode import probabilistic_rejection_sampler_utils, rejection_sampler from vllm.v1.worker.gpu.spec_decode.eagle import speculator +from vllm_ascend.utils import vllm_version_is from vllm_ascend.worker.v2.input_batch import post_update from vllm_ascend.worker.v2.sample.bad_words import apply_bad_words from vllm_ascend.worker.v2.sample.gumbel import apply_temperature, gumbel_sample @@ -25,3 +26,11 @@ gumbel.apply_temperature = apply_temperature states.apply_temperature = apply_temperature logprob.compute_token_logprobs = compute_token_logprobs + +if not vllm_version_is("0.20.1"): + from vllm_ascend.worker.v2.spec_decode.probabilistic_rejection_sampler_utils import ( + probabilistic_rejection_sample as npu_probabilistic_rejection_sample, + ) + + probabilistic_rejection_sampler_utils.probabilistic_rejection_sample = npu_probabilistic_rejection_sample + rejection_sampler.probabilistic_rejection_sample = npu_probabilistic_rejection_sample diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 8600465ad0d..fa5f4443298 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -50,6 +50,11 @@ def __init__( ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode + # Added for compatibility with InputBatch methods that reference these + # attributes after PR vllm-project/vllm#34668. NPU does not use + # thinking budget, so the holder is always None. 
+ self.thinking_budget_state_holder = None + self.thinking_token_budget_reqs: set[str] = set() self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_batched_tokens = max_num_batched_tokens diff --git a/vllm_ascend/worker/v2/aclgraph_utils.py b/vllm_ascend/worker/v2/aclgraph_utils.py index 444f84d6120..c681cdfd42c 100644 --- a/vllm_ascend/worker/v2/aclgraph_utils.py +++ b/vllm_ascend/worker/v2/aclgraph_utils.py @@ -113,7 +113,7 @@ def capture( """Capture CUDA graphs for model forward pass.""" model = ModelWithContext(model) with communicator_switch(): - super().capture( + return super().capture( model, model_state, input_buffers, diff --git a/vllm_ascend/worker/v2/sample/logprob.py b/vllm_ascend/worker/v2/sample/logprob.py index f157d58cc3f..e2fabd5fad1 100644 --- a/vllm_ascend/worker/v2/sample/logprob.py +++ b/vllm_ascend/worker/v2/sample/logprob.py @@ -22,6 +22,10 @@ from vllm.v1.outputs import LogprobsTensors from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +from vllm_ascend.utils import vllm_version_is + +if not vllm_version_is("0.20.1"): + from vllm.v1.worker.gpu.sample.logprob import LogprobTokenIdsState @triton.jit @@ -120,6 +124,9 @@ def compute_topk_logprobs( num_logprobs: int, sampled_token_ids: torch.Tensor, cu_num_logits: list[int] | None = None, + logprob_token_ids_state: "LogprobTokenIdsState | None" = None, + expanded_idx_mapping: torch.Tensor | None = None, + max_per_req_token_ids: int = 0, ) -> LogprobsTensors: assert num_logprobs >= 0 batch_size, vocab_size = logits.shape diff --git a/vllm_ascend/worker/v2/spec_decode/__init__.py b/vllm_ascend/worker/v2/spec_decode/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/worker/v2/spec_decode/eagle/aclgraph.py b/vllm_ascend/worker/v2/spec_decode/eagle/aclgraph.py index ce7bb93e07f..aaa7cf93359 100644 --- a/vllm_ascend/worker/v2/spec_decode/eagle/aclgraph.py +++ b/vllm_ascend/worker/v2/spec_decode/eagle/aclgraph.py @@ -14,7 +14,11 @@ from vllm.v1.worker.gpu.cudagraph_utils import BatchExecutionDescriptor from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState -from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager +from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import ( + CapturedAttentionState, + DecodeEagleCudaGraphManager, + PrefillEagleCudaGraphManager, +) from vllm.v1.worker.utils import AttentionGroup from vllm_ascend.ascend_forward_context import _EXTRA_CTX @@ -27,7 +31,96 @@ from vllm_ascend.worker.v2.utils import communicator_switch -class EagleAclGraphManager(EagleCudaGraphManager): +class PrefillEagleAclGraphManager(PrefillEagleCudaGraphManager): + """AclGraphManager for Eagle speculative decoding.""" + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + cudagraph_mode: CUDAGraphMode, + decode_query_len: int, + speculator: Any, + ): + super().__init__(vllm_config, device, cudagraph_mode, decode_query_len) + + # set speculator attribute, so we can access attributes speculator + # when call `run_fullgraph` method in CudaGraphManager, + # then we don't need to # copy `propose` method in `AscendEagleSpeculator` class. + self.speculator = speculator + # capture_sizes sorts in ascending order. + self.capture_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) + # vllm-ascend need to update draft graph params of attention backend. + # so we need to set draft graph params before capture full graph. 
+ # `prefill` graph and `decodes` graph are different, `decode_query_len` can be used to distinguish them + self.is_draft_model_prefill = decode_query_len > 1 + if super().needs_capture(): + if self.is_draft_model_prefill: + set_draft_graph_prefill_params(self.capture_sizes) + else: + set_draft_graph_params(self.capture_sizes) + + def capture( + self, + forward_fn: Callable, + full_cg_attn_states: dict[BatchExecutionDescriptor, CapturedAttentionState], + progress_bar_desc: str = "Capturing CUDA graphs", + ) -> None: + """Capture ACL graphs for Eagle.""" + with communicator_switch(), model_capture_wrapper(self.speculator, self.is_draft_model_prefill): + super().capture( + forward_fn, + full_cg_attn_states, + progress_bar_desc, + ) + + def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """Override run_fullgraph to update full graph params in run_fullgraph.""" + num_tokens = desc.num_tokens + if self.is_draft_model_prefill: + logger.info_once(f"draft prefill run_fullgraph with num_tokens={num_tokens}") + else: + logger.info_once(f"draft run_fullgraph with num_tokens={num_tokens}") + + draft_attn_metadatas = self.speculator.build_draft_attn_metadatas(desc.num_reqs, self.is_draft_model_prefill) + + ret = super().run_fullgraph(desc) + + positions = self.speculator.input_buffers.positions[:num_tokens] + # refer to vllm.v1.worker.gpu.dp_utils.sync_cudagraph_and_dp_padding to + # calculate num_tokens_across_dp. + num_tokens_across_dp = torch.full([self.speculator.dp_size], num_tokens, device=self.device) + with set_forward_context( + self.speculator.model_state.attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=desc.cg_mode, + num_tokens_across_dp=num_tokens_across_dp, + batch_descriptor=None, # Full graph model don't need batch_descriptor + slot_mapping=None, + ): + # decide to update draft graph params + _EXTRA_CTX.is_draft_model = True + + # decide to run `prefill` graph or `decodes` graph + _EXTRA_CTX.is_draft_model_prefill = self.is_draft_model_prefill + + forward_context = get_forward_context() + update_full_graph_params( + # FIXME(Ronald1995): support hybrid attn backend + list(self.speculator.attn_backends.values())[0], + self.speculator.update_stream, + forward_context, + num_tokens, + self.vllm_config, + self.speculator.speculative_config, + positions.shape[0], + draft_attn_metadatas=draft_attn_metadatas, + ) + return ret + + +class DecodeEagleAclGraphManager(DecodeEagleCudaGraphManager): """AclGraphManager for Eagle speculative decoding.""" def __init__( diff --git a/vllm_ascend/worker/v2/spec_decode/eagle/speculator.py b/vllm_ascend/worker/v2/spec_decode/eagle/speculator.py index 1f0fc0c7798..aff94865047 100644 --- a/vllm_ascend/worker/v2/spec_decode/eagle/speculator.py +++ b/vllm_ascend/worker/v2/spec_decode/eagle/speculator.py @@ -31,13 +31,13 @@ from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.spec_decode.eagle import speculator as vllm_speculator -from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager -from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator, gumbel_sample, update_eagle_inputs +from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import PrefillEagleCudaGraphManager +from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator, update_eagle_draft_inputs from vllm_ascend.attention.attention_v1 import 
AscendAttentionState from vllm_ascend.worker.v2.attn_utils import build_attn_metadata from vllm_ascend.worker.v2.input_batch import AscendInputBuffers -from vllm_ascend.worker.v2.spec_decode.eagle.aclgraph import EagleAclGraphManager +from vllm_ascend.worker.v2.spec_decode.eagle.aclgraph import PrefillEagleAclGraphManager class AscendEagleSpeculator(EagleSpeculator): @@ -174,51 +174,42 @@ def generate_draft( """ self._init_decode_attn_metadata(attn_metadata, num_reqs) self._increment_decode_attn_metadata(attn_metadata) - - # NOTE(drslark): following lines (from 145 to 184) come from raw gpu's generate_draft logic - pos = self.input_buffers.positions[:num_reqs] - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] idx_mapping = self.idx_mapping[:num_reqs] - for step in range(1, self.num_speculative_steps): - # Run the eagle model. - last_hidden_states, hidden_states = self.run_model( - num_tokens_padded, - attn_metadata, - slot_mappings, - num_tokens_across_dp, - cudagraph_runtime_mode, - ) - last_hidden_states = last_hidden_states[:num_reqs] - hidden_states = hidden_states[:num_reqs] - logits = self.model.compute_logits(last_hidden_states) - - # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise - # used for draft and target sampling. - draft_tokens = gumbel_sample( - logits, - idx_mapping, - self.temperature, - self.seeds, - pos + 1, - apply_temperature=True, - processed_logits_out=self.draft_logits[:, step] if self.draft_logits is not None else None, - ) - self.draft_tokens[:num_reqs, step] = draft_tokens - - if step < self.num_speculative_steps - 1: - # Update the inputs for the next step. - update_eagle_inputs( - draft_tokens, - hidden_states, - self.input_buffers, - self.hidden_states, - self.max_model_len, - ) - if attn_metadata is not None: - self.block_tables.compute_slot_mappings(idx_mapping, query_start_loc, pos, num_tokens_padded) - - # npu's own update logic - self._increment_decode_attn_metadata(attn_metadata) + positions = self.input_buffers.positions[:num_reqs] + # Run the eagle model forward pass. + last_hidden_states, hidden_states = self.run_model( + num_tokens_padded, + attn_metadata, + slot_mappings, + num_tokens_across_dp, + cudagraph_runtime_mode, + ) + last_hidden_states = last_hidden_states[:num_reqs] + + # Sample the draft tokens. + logits = self.model.compute_logits(last_hidden_states) + draft_tokens = self._sample_draft( + logits, + idx_mapping, + positions, + self.current_draft_step, + self.draft_logits, + ) + + # Update the inputs for the next step. 
+ update_eagle_draft_inputs( + draft_tokens, + self.current_draft_step, + hidden_states, + self.draft_tokens, + self.hidden_states, + self.input_buffers, + num_reqs, + self.max_model_len, + self.num_speculative_steps, + ) + # npu's own update logic + self._increment_decode_attn_metadata(attn_metadata) @torch.inference_mode() def run_model( @@ -385,13 +376,13 @@ def torch_gather_wrapper(): @contextmanager def graph_manager_wrapper(speculator): """Context manager to override graph manager.""" - original_graph_manager = EagleCudaGraphManager + original_graph_manager = PrefillEagleCudaGraphManager def factory(vllm_config: VllmConfig, device: torch.device, cudagraph_mode: CUDAGraphMode, decode_query_len: int): - return EagleAclGraphManager(vllm_config, device, cudagraph_mode, decode_query_len, speculator) + return PrefillEagleAclGraphManager(vllm_config, device, cudagraph_mode, decode_query_len, speculator) try: - vllm_speculator.EagleCudaGraphManager = factory + vllm_speculator.PrefillEagleCudaGraphManager = factory yield finally: - vllm_speculator.EagleCudaGraphManager = original_graph_manager + vllm_speculator.PrefillEagleCudaGraphManager = original_graph_manager diff --git a/vllm_ascend/worker/v2/spec_decode/probabilistic_rejection_sampler_utils.py b/vllm_ascend/worker/v2/spec_decode/probabilistic_rejection_sampler_utils.py new file mode 100644 index 00000000000..0f734659f42 --- /dev/null +++ b/vllm_ascend/worker/v2/spec_decode/probabilistic_rejection_sampler_utils.py @@ -0,0 +1,469 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+ +import torch +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import ( + _compute_block_stats_kernel, + _compute_global_lse, + _insert_resampled_kernel, +) + + +@triton.jit +def _npu_gumbel_block_argmax( + logits, + block, + mask, + token_idx, + expanded_idx_mapping_ptr, + temp_ptr, + seeds_ptr, + pos_ptr, + processed_logits_ptr, + processed_logits_stride, + processed_logits_col_ptr, + vocab_size, + APPLY_TEMPERATURE: tl.constexpr, +): + req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx) + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) + if temp != 0.0 and APPLY_TEMPERATURE: + logits = logits / temp + + if processed_logits_ptr is not None: + if processed_logits_col_ptr is not None: + col = tl.load(processed_logits_col_ptr) + else: + col = 0 + tl.store( + processed_logits_ptr + req_state_idx * processed_logits_stride + col * vocab_size + block, + logits, + mask=mask, + ) + + logits = logits.to(tl.float32) + if temp != 0.0: + seed = tl.load(seeds_ptr + req_state_idx) + # NPU: cast pos to int32 to avoid uint64 in philox (NPU umulhi only + # supports int32/uint32). Position values fit in int32 in practice. + pos = tl.load(pos_ptr + token_idx).to(tl.int32) + gumbel_seed = tl.randint(seed, pos) + # NPU: use tl.rand (float32) instead of tl_rand64 (float64 not supported) + r = tl.rand(gumbel_seed, block).to(tl.float32) + gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) + logits = tl.where(mask, logits + gumbel_noise, float("-inf")) + + value, idx = tl.max(logits, axis=0, return_indices=True) + return value, idx + + +@triton.jit +def _resample_kernel( + # [num_reqs, num_blocks] + resampled_local_argmax_ptr, + resampled_local_argmax_stride, + # [num_reqs, num_blocks] + resampled_local_max_ptr, + resampled_local_max_stride, + # [num_logits, V] + target_logits_ptr, + target_logits_stride, + # [num_reqs] + target_rejected_logsumexp_ptr, + # [max_num_reqs, num_speculative_steps, V] + draft_logits_ptr, + draft_logits_stride_0, + draft_logits_stride_1, + # [num_reqs] + draft_rejected_logsumexp_ptr, + # [num_reqs] + rejected_step_ptr, + # [num_reqs + 1] + cu_num_logits_ptr, + # [num_logits] + expanded_idx_mapping_ptr, + # [num_logits] + draft_sampled_ptr, + # [max_num_reqs] + temp_ptr, + # [max_num_reqs] + seed_ptr, + # [num_logits] + pos_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, + HAS_DRAFT_LOGITS: tl.constexpr, +): + req_idx = tl.program_id(0) + resample_idx = tl.load(rejected_step_ptr + req_idx) + start_idx = tl.load(cu_num_logits_ptr + req_idx) + end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + resample_token_idx = start_idx + resample_idx + req_state_idx = tl.load(expanded_idx_mapping_ptr + resample_token_idx) + + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) + is_bonus = resample_token_idx == end_idx - 1 + if temp == 0.0 and not is_bonus: + return + + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + target_logits = tl.load( + target_logits_ptr + resample_token_idx * target_logits_stride + block, + mask=mask, + other=float("-inf"), + ).to(tl.float32) + + if is_bonus: + residual_logits = target_logits + elif HAS_DRAFT_LOGITS: + draft_logits = tl.load( + draft_logits_ptr + req_state_idx * draft_logits_stride_0 + resample_idx * draft_logits_stride_1 + block, + mask=mask, + other=float("-inf"), + ).to(tl.float32) + target_lse = tl.load(target_rejected_logsumexp_ptr + req_idx) + draft_lse = tl.load(draft_rejected_logsumexp_ptr 
+ req_idx) + target_log_probs = target_logits - target_lse + draft_log_probs = draft_logits - draft_lse + ratio = tl.exp(draft_log_probs - target_log_probs) + residual_logits = tl.where( + ratio < 1.0, + target_log_probs + tl.log(1 - ratio), + float("-inf"), + ).to(tl.float32) + else: + rejected_draft_token = tl.load(draft_sampled_ptr + resample_token_idx + 1) + residual_logits = tl.where( + block != rejected_draft_token, + target_logits, + float("-inf"), + ).to(tl.float32) + + value, idx = _npu_gumbel_block_argmax( + residual_logits, + block, + mask, + resample_token_idx, + expanded_idx_mapping_ptr, + temp_ptr, + seed_ptr, + pos_ptr, + None, + 0, + None, + vocab_size, + APPLY_TEMPERATURE=False, + ) + token_id = block_idx * BLOCK_SIZE + idx + tl.store( + resampled_local_argmax_ptr + req_idx * resampled_local_argmax_stride + block_idx, + token_id, + ) + tl.store( + resampled_local_max_ptr + req_idx * resampled_local_max_stride + block_idx, + value, + ) + + +@triton.jit +def _probabilistic_rejection_kernel( + # [num_reqs, num_speculative_steps + 1] + sampled_ptr, + sampled_stride, + # [num_reqs] + rejected_steps_ptr, + # [num_reqs] + target_rejected_logsumexp_ptr, + # [num_reqs] + draft_rejected_logsumexp_ptr, + # [num_logits, V] + target_logits_ptr, + target_logits_stride, + # [num_logits, num_blocks] + target_local_argmax_ptr, + target_local_argmax_stride, + # [num_logits, num_blocks] + target_local_max_ptr, + target_local_max_stride, + # [num_logits, num_blocks] + target_local_sumexp_ptr, + target_local_sumexp_stride, + # [num_logits] + draft_sampled_ptr, + # [max_num_reqs, num_speculative_steps, V] + draft_logits_ptr, + draft_logits_stride_0, + draft_logits_stride_1, + # [num_logits, num_blocks] + draft_local_max_ptr, + draft_local_max_stride, + # [num_logits, num_blocks] + draft_local_sumexp_ptr, + draft_local_sumexp_stride, + # [num_reqs + 1] + cu_num_logits_ptr, + # [num_reqs] + idx_mapping_ptr, + # [max_num_reqs] + temp_ptr, + # [max_num_reqs] + seed_ptr, + # [num_logits] + pos_ptr, + vocab_num_blocks, + PADDED_VOCAB_NUM_BLOCKS: tl.constexpr, + HAS_DRAFT_LOGITS: tl.constexpr, +): + req_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + req_idx) + start_idx = tl.load(cu_num_logits_ptr + req_idx) + end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + num_tokens = end_idx - start_idx + seed = tl.load(seed_ptr + req_state_idx) # noqa: F841 + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) + + rejected_step = 0 + target_lse = 0.0 + draft_lse = 0.0 + accepted = True + for i in range(num_tokens - 1): + if accepted: + logit_idx = start_idx + i + draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1) + if temp == 0.0: + # Greedy sampling. Accept IFF draft matches target argmax. + # NOTE: Target argmax is stored directly so that resampling + # can be skipped upon rejection. 
+ target_blocks = tl.arange(0, PADDED_VOCAB_NUM_BLOCKS) + target_blocks_mask = target_blocks < vocab_num_blocks + target_local_max = tl.load( + target_local_max_ptr + logit_idx * target_local_max_stride + target_blocks, + mask=target_blocks_mask, + other=float("-inf"), + ) + max_target_block_idx = tl.argmax(target_local_max, axis=0) + target_argmax = tl.load( + target_local_argmax_ptr + logit_idx * target_local_argmax_stride + max_target_block_idx + ) + accepted &= target_argmax == draft_sampled + tl.store(sampled_ptr + req_idx * sampled_stride + i, target_argmax) + else: + target_logit = tl.load(target_logits_ptr + logit_idx * target_logits_stride + draft_sampled).to( + tl.float32 + ) + target_lse = _compute_global_lse( + target_local_max_ptr, + target_local_max_stride, + target_local_sumexp_ptr, + target_local_sumexp_stride, + logit_idx, + vocab_num_blocks, + PADDED_VOCAB_NUM_BLOCKS, + ) + target_log_prob = target_logit - target_lse + # NPU does not support tl_rand64; always accept the draft token. + u = tl.full([], 0.0, dtype=tl.float32) + if HAS_DRAFT_LOGITS: + draft_logit = tl.load( + draft_logits_ptr + + req_state_idx * draft_logits_stride_0 + + i * draft_logits_stride_1 + + draft_sampled + ).to(tl.float32) + draft_lse = _compute_global_lse( + draft_local_max_ptr, + draft_local_max_stride, + draft_local_sumexp_ptr, + draft_local_sumexp_stride, + logit_idx, + vocab_num_blocks, + PADDED_VOCAB_NUM_BLOCKS, + ) + draft_log_prob = draft_logit - draft_lse + else: + # One-hot draft: q(draft_token) = 1, log_q = 0. + draft_log_prob = 0 + # Probability ratio test: p(x) > u * q(x) + # Equivalent log form: log_p(x) > log(u) + log_q(x) + accepted &= target_log_prob > tl.log(u) + draft_log_prob + tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled) + rejected_step += accepted + tl.store(rejected_steps_ptr + req_idx, rejected_step) + tl.store(target_rejected_logsumexp_ptr + req_idx, target_lse) + tl.store(draft_rejected_logsumexp_ptr + req_idx, draft_lse) + + +def probabilistic_rejection_sample( + # [num_logits, V] + target_logits: torch.Tensor, + # [max_num_reqs, num_speculative_steps, V] + draft_logits: torch.Tensor | None, + # [num_logits] + draft_sampled: torch.Tensor, + # [num_reqs + 1] + cu_num_logits: torch.Tensor, + # [num_logits] + pos: torch.Tensor, + # [num_reqs] + idx_mapping: torch.Tensor, + # [num_logits] + expanded_idx_mapping: torch.Tensor, + # [num_logits] + expanded_local_pos: torch.Tensor, + # [max_num_reqs] + temperature: torch.Tensor, + # [max_num_reqs] + seed: torch.Tensor, + num_speculative_steps: int, +) -> tuple[torch.Tensor, torch.Tensor]: + num_reqs = cu_num_logits.shape[0] - 1 + num_logits, vocab_size = target_logits.shape + has_draft_logits = draft_logits is not None + + if draft_logits is None: + # When draft_logits is None, create a dummy tensor so that Triton + # kernel signatures receive valid pointers/strides. The kernels + # will never read from it when HAS_DRAFT_LOGITS=False. + draft_logits = target_logits.new_empty(1, 1, 1) + + # Compute the block-level logits stats, such as target argmax + # (for greedy requests), and target max + softmax exponential + # (for non-greedy requests). 
+ VOCAB_BLOCK_SIZE = 8192 + vocab_num_blocks = triton.cdiv(vocab_size, VOCAB_BLOCK_SIZE) + padded_vocab_num_blocks = triton.next_power_of_2(vocab_num_blocks) + target_local_argmax = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.int64) + target_local_max = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32) + target_local_sumexp = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32) + draft_local_max = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32) + draft_local_sumexp = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32) + _compute_block_stats_kernel[(num_logits, vocab_num_blocks)]( + target_local_argmax, + target_local_argmax.stride(0), + target_local_max, + target_local_max.stride(0), + target_local_sumexp, + target_local_sumexp.stride(0), + draft_local_max, + draft_local_max.stride(0), + draft_local_sumexp, + draft_local_sumexp.stride(0), + target_logits, + target_logits.stride(0), + draft_logits, + draft_logits.stride(0), + draft_logits.stride(1), + expanded_idx_mapping, + expanded_local_pos, + temperature, + vocab_size, + num_speculative_steps, + BLOCK_SIZE=VOCAB_BLOCK_SIZE, + HAS_DRAFT_LOGITS=has_draft_logits, + ) + + # Sample up until the first rejected/bonus token, and store + # the step. + sampled = draft_sampled.new_empty(num_reqs, num_speculative_steps + 1, dtype=torch.int64) + num_sampled = sampled.new_empty(num_reqs, dtype=torch.int32) + target_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32) + draft_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32) + _probabilistic_rejection_kernel[(num_reqs,)]( + sampled, + sampled.stride(0), + num_sampled, + target_rejected_logsumexp, + draft_rejected_logsumexp, + target_logits, + target_logits.stride(0), + target_local_argmax, + target_local_argmax.stride(0), + target_local_max, + target_local_max.stride(0), + target_local_sumexp, + target_local_sumexp.stride(0), + draft_sampled, + draft_logits, + draft_logits.stride(0), + draft_logits.stride(1), + draft_local_max, + draft_local_max.stride(0), + draft_local_sumexp, + draft_local_sumexp.stride(0), + cu_num_logits, + idx_mapping, + temperature, + seed, + pos, + vocab_num_blocks, + PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks, + HAS_DRAFT_LOGITS=has_draft_logits, + num_warps=1, + ) + + # Resample the rejected/bonus tokens. + RESAMPLE_BLOCK_SIZE = 1024 + resample_num_blocks = triton.cdiv(vocab_size, RESAMPLE_BLOCK_SIZE) + padded_resample_num_blocks = triton.next_power_of_2(resample_num_blocks) + resampled_local_argmax = target_logits.new_empty(num_reqs, resample_num_blocks, dtype=torch.int64) + # NPU does not support float64; use float32 for resampled_local_max. + resampled_local_max = target_logits.new_empty(num_reqs, resample_num_blocks, dtype=torch.float32) + _resample_kernel[(num_reqs, resample_num_blocks)]( + resampled_local_argmax, + resampled_local_argmax.stride(0), + resampled_local_max, + resampled_local_max.stride(0), + target_logits, + target_logits.stride(0), + target_rejected_logsumexp, + draft_logits, + draft_logits.stride(0), + draft_logits.stride(1), + draft_rejected_logsumexp, + num_sampled, + cu_num_logits, + expanded_idx_mapping, + draft_sampled, + temperature, + seed, + pos, + vocab_size, + BLOCK_SIZE=RESAMPLE_BLOCK_SIZE, + HAS_DRAFT_LOGITS=has_draft_logits, + ) + + # Insert the resampled tokens into the output sampled. 
+ _insert_resampled_kernel[(num_reqs,)]( + sampled, + sampled.stride(0), + num_sampled, + resampled_local_argmax, + resampled_local_argmax.stride(0), + resampled_local_max, + resampled_local_max.stride(0), + resample_num_blocks, + cu_num_logits, + expanded_idx_mapping, + temperature, + PADDED_RESAMPLE_NUM_BLOCKS=padded_resample_num_blocks, + ) + return sampled, num_sampled diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index c42fbce3f74..722ec077021 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -44,10 +44,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors -from vllm.v1.worker.worker_base import ( - CompilationTimes, # noqa: E402 - WorkerBase, -) +from vllm.v1.worker.worker_base import CompilationTimes, WorkerBase from vllm.v1.worker.workspace import init_workspace_manager import vllm_ascend.envs as envs_ascend
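The new probabilistic_rejection_sampler_utils port above keeps the upstream kernel structure but, as its in-kernel comments note, drops the float64 uniform draw that Triton on NPU cannot produce and fixes u = 0, so non-greedy requests currently always accept the draft token. As a reference-only sketch (plain PyTorch, illustrative names, not the Triton kernels themselves), the acceptance rule those kernels implement is:

# Reference-only sketch of the acceptance test used by the kernels above;
# tensor names here are illustrative and not taken from the kernel code.
import torch


def accept_draft_token(
    target_log_prob: torch.Tensor,  # log p(x) under the target model
    draft_log_prob: torch.Tensor,   # log q(x) under the draft model
    u: torch.Tensor,                # uniform (0, 1] draw per token
) -> torch.Tensor:
    # Accept iff p(x) > u * q(x), evaluated in log space:
    #   log p(x) > log(u) + log q(x)
    return target_log_prob > torch.log(u) + draft_log_prob

With u drawn from torch.rand_like(draft_log_prob) this reproduces the standard speculative-decoding acceptance rule; with u fixed at zero, log(u) becomes -inf and the test always passes, which matches the NPU kernel's current fallback.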