Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/bot_pr_create.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
steps:
- name: Get vLLM version
run: |
VLLM_COMMIT=4034c3d32e30d01639459edd3ab486f56993876d
VLLM_COMMIT=5b3ba94ab4bd9da739bcc27cdd05505467fa499e
echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> "$GITHUB_ENV"

- name: Checkout repository
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/dockerfiles/Dockerfile.lint
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ RUN apt-get update -y && \

ARG VLLM_REPO=https://github.com/vllm-project/vllm.git
# For lint purposes, we actually need to make a main2main match.
ARG VLLM_COMMIT=4034c3d32e30d01639459edd3ab486f56993876d
ARG VLLM_COMMIT=5b3ba94ab4bd9da739bcc27cdd05505467fa499e
RUN git clone $VLLM_REPO /vllm-workspace/vllm && \
cd /vllm-workspace/vllm && \
git checkout $VLLM_COMMIT
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_test_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
name: e2e-full
strategy:
matrix:
vllm_version: [4034c3d32e30d01639459edd3ab486f56993876d, v0.17.0]
vllm_version: [5b3ba94ab4bd9da739bcc27cdd05505467fa499e, v0.17.0]
needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.e2e_tracker == true }}
uses: ./.github/workflows/_e2e_test.yaml
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pr_test_light.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
lint:
uses: ./.github/workflows/_pre_commit.yml
with:
vllm: 4034c3d32e30d01639459edd3ab486f56993876d
vllm: 5b3ba94ab4bd9da739bcc27cdd05505467fa499e
changes:
runs-on: linux-aarch64-a2b3-0
outputs:
Expand Down Expand Up @@ -90,7 +90,7 @@ jobs:
if: ${{ needs.lint.result == 'success' && (needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.ut_tracker == 'true') }}
strategy:
matrix:
vllm_version: [4034c3d32e30d01639459edd3ab486f56993876d, v0.17.0]
vllm_version: [5b3ba94ab4bd9da739bcc27cdd05505467fa499e, v0.17.0]
uses: ./.github/workflows/_unit_test.yaml
with:
vllm: ${{ matrix.vllm_version }}
Expand All @@ -102,7 +102,7 @@ jobs:
name: e2e-light
strategy:
matrix:
vllm_version: [4034c3d32e30d01639459edd3ab486f56993876d, v0.17.0]
vllm_version: [5b3ba94ab4bd9da739bcc27cdd05505467fa499e, v0.17.0]
# Note (yikun): If CI resources are limited, we can split the job into two chained jobs
needs: [lint, changes]
# Only trigger the e2e test after lint has passed and the pull request's change is e2e-related.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/schedule_codecov_refresh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
name: refresh codecov
strategy:
matrix:
vllm_version: [4034c3d32e30d01639459edd3ab486f56993876d]
vllm_version: [5b3ba94ab4bd9da739bcc27cdd05505467fa499e]
uses: ./.github/workflows/_unit_test.yaml
with:
vllm: ${{ matrix.vllm_version }}
Expand Down
2 changes: 1 addition & 1 deletion docs/source/community/versioning_policy.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ For main branch of vLLM Ascend, we usually make it compatible with the latest vL

| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu |
|-------------|--------------|------------------|-------------|--------------------|
| main | 4034c3d32e30d01639459edd3ab486f56993876d, v0.17.0 tag | >= 3.10, < 3.12 | 8.5.0 | 2.9.0 / 2.9.0 |
| main | 5b3ba94ab4bd9da739bcc27cdd05505467fa499e, v0.17.0 tag | >= 3.10, < 3.12 | 8.5.0 | 2.9.0 / 2.9.0 |

## Release cadence

Expand Down
47 changes: 47 additions & 0 deletions vllm_ascend/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#
import copy
import functools
import logging
from collections.abc import Callable
from typing import Any

Expand All @@ -33,6 +34,40 @@
from vllm_ascend.ascend_config import AscendCompilationConfig, get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY

logger = logging.getLogger(__name__)


def convert_fake_inputs_to_current_fake_mode(example_inputs: list[Any]) -> list[Any]:
"""Fix for FakeTensorMode mismatch issue in vllm upgrade

The piecewise backend now compiles ranges upfront in __init__, which may use
fake tensors from graph placeholder nodes that have a different FakeTensorMode
than the current tracing context. We need to ensure consistent fake mode.
"""
from torch._guards import detect_fake_mode

current_fake_mode = detect_fake_mode()
if current_fake_mode is not None:
# Convert example_inputs to use the current fake mode if they are fake tensors
# from a different fake mode
converted_inputs = []
for inp in example_inputs:
if isinstance(inp, torch.Tensor):
# Check if this is a fake tensor that needs conversion
if hasattr(inp, "fake_mode") and inp.fake_mode is not current_fake_mode:
# Convert to current fake mode
old_fake_mode = inp.fake_mode
converted_inputs.append(current_fake_mode.from_tensor(inp))
logger.debug("Converting fake tensor from fake_mode %s to %s", old_fake_mode, current_fake_mode)
else:
converted_inputs.append(inp)
else:
converted_inputs.append(inp)
return converted_inputs
else:
logger.warning("detect_fake_mode() returned None. FakeTensorMode mismatch fix may not be applied.")
return example_inputs


def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
Expand All @@ -49,6 +84,9 @@ def fusion_pass_compile(
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
# Fix for FakeTensorMode mismatch issue in vllm upgrade
example_inputs = convert_fake_inputs_to_current_fake_mode(example_inputs)

def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph)
Expand Down Expand Up @@ -101,6 +139,15 @@ def npugraph_ex_compile(

npugraph_ex = torchair.get_npu_backend(compiler_config=config)

# Apply graph fusion passes (including GELU replacement) before torchair compilation
# This is needed to replace unsupported operations like aten::gelu with NPU-compatible versions
if COMPILATION_PASS_KEY in compiler_config:
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph)

# Fix for FakeTensorMode mismatch issue in vllm upgrade
example_inputs = convert_fake_inputs_to_current_fake_mode(example_inputs)

# torch.compile requires the output of the fx graph to be a tuple
if not graph_returns_tuple(graph):
return make_graph_return_tuple(graph, example_inputs, npugraph_ex), None
Expand Down
8 changes: 8 additions & 0 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,11 @@ def configure(self, config: VllmConfig):
from .passes.sequence_parallelism import AscendSequenceParallelismPass

self.passes.append(AscendSequenceParallelismPass(config))

# GELU replacement pass - needed for models like Whisper that use GELU
# which is not natively supported on NPU. Uses torch_npu.npu_gelu
# which provides exact GELU computation on NPU devices.
if self.ascend_compilation_config.get("replace_gelu", True):
from .passes.gelu_replacement_pass import GeluReplacementPass

self.passes.append(GeluReplacementPass(config))
210 changes: 210 additions & 0 deletions vllm_ascend/compilation/passes/gelu_replacement_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import torch
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger

from vllm_ascend.compilation.passes.base_pattern import BasePattern
from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.15.0"):
from vllm.compilation.vllm_inductor_pass import VllmInductorPass # type: ignore
else:
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass


class GeluPattern(BasePattern):
    """
    Rewrite rule: ``torch.ops.aten.gelu.default`` -> ``torch_npu.npu_gelu``.

    Rationale recorded by the original author: aten::gelu is not supported on
    NPU and falls back to CPU, which breaks graph capture because of
    host-device synchronization restrictions.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Build the sample arguments for the pattern: a single random tensor of
        shape (2, 2048) on the "npu" device with dtype ``self.dtype``.
        """
        return [torch.randn(2, 2048, device="npu", dtype=self.dtype)]

    def get_pattern(self):
        """Return the callable describing the op sequence to search for."""

        def pattern(x: torch.Tensor):
            # ``approximate`` is deliberately not passed, so this matches the
            # default form where approximate="none" is implied.
            return torch.ops.aten.gelu.default(x)

        return pattern

    def get_replacement(self):
        """Return the callable whose trace is substituted for each match."""

        def replacement(x: torch.Tensor):
            # torch_npu is imported lazily, only when the replacement is traced.
            import torch_npu

            return torch_npu.npu_gelu(x)

        return replacement


class GeluInplacePattern(BasePattern):
    """
    Rewrite rule for the in-place op ``torch.ops.aten.gelu_.default``, which is
    replaced by ``torch_npu.npu_gelu`` followed by a copy back into the input.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Build the sample arguments for the pattern: a single random tensor of
        shape (2, 2048) on the "npu" device with dtype ``self.dtype``.
        """
        return [torch.randn(2, 2048, device="npu", dtype=self.dtype)]

    def get_pattern(self):
        """Return the callable describing the op sequence to search for."""

        def pattern(x: torch.Tensor):
            # In-place GELU applied directly to the input tensor.
            return torch.ops.aten.gelu_.default(x)

        return pattern

    def get_replacement(self):
        """Return the callable whose trace is substituted for each match."""

        def replacement(x: torch.Tensor):
            import torch_npu

            # npu_gelu produces a new tensor, so the result is written back
            # into ``x`` to keep the in-place contract, and ``x`` is returned.
            activated = torch_npu.npu_gelu(x)
            x.copy_(activated)
            return x

        return replacement


class GeluOutPattern(BasePattern):
    """
    Rewrite rule for ``torch.ops.aten.gelu.out`` (GELU writing into a
    caller-supplied output tensor), replaced by ``torch_npu.npu_gelu`` plus a
    copy into that output tensor.

    Per the original author's note, torch.compile emits this variant when
    graphs are compiled upfront (e.g. after vllm commit 5569f5218, which stops
    lazy compilation). The op signature is:
        aten::gelu.out(Tensor self, *, str approximate="none", Tensor(a!) out)
            -> Tensor(a!)
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Build the sample arguments for the pattern, in call order
        ``gelu.out(x, out=out)``: a random (2, 2048) tensor on the "npu"
        device with dtype ``self.dtype``, and an uninitialized tensor shaped
        like it.
        """
        source = torch.randn(2, 2048, device="npu", dtype=self.dtype)
        return [source, torch.empty_like(source)]

    def get_pattern(self):
        """Return the callable describing the op sequence to search for."""

        def pattern(x: torch.Tensor, out: torch.Tensor):
            # Matches: torch.ops.aten.gelu.out(x, out=out)
            return torch.ops.aten.gelu.out(x, out=out)

        return pattern

    def get_replacement(self):
        """Return the callable whose trace is substituted for each match."""

        def replacement(x: torch.Tensor, out: torch.Tensor):
            import torch_npu

            # Compute with npu_gelu, then store into the provided output
            # tensor, which is what gets returned.
            activated = torch_npu.npu_gelu(x)
            out.copy_(activated)
            return out

        return replacement


class GeluReplacementPass(VllmInductorPass):
    """
    Pass that swaps aten GELU ops for NPU-compatible implementations so that
    graphs can be captured on Ascend NPU devices.

    Three op variants are registered:
    - gelu.default: standard GELU
    - gelu_.default: in-place GELU
    - gelu.out: GELU with a pre-allocated output tensor
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="gelu_replacement_pass")

        model_dtype = vllm_config.model_config.dtype
        dtype_supported = model_dtype in (torch.float16, torch.bfloat16, torch.float32)
        if not dtype_supported:
            # Leave the matcher empty: applying the pass then replaces nothing.
            logger.debug("GELU replacement not enabled: unsupported dtype %s", model_dtype)
            return

        # Registration order is kept: default, in-place, then out variant.
        for pattern_cls in (GeluPattern, GeluInplacePattern, GeluOutPattern):
            pattern_cls(vllm_config).register(self.pattern_match_passes)

    def __call__(self, graph: torch.fx.Graph) -> None:  # type: ignore[override]
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.begin()
        replaced = self.pattern_match_passes.apply(graph)
        self.matched_count = replaced
        if replaced > 0:
            logger.debug("Replaced %s gelu operations with NPU-compatible version", replaced)
        self.end_and_log()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Report that the pass applies to every compile range; ``compile_range``
        is not inspected.
        """
        return True
1 change: 1 addition & 0 deletions vllm_ascend/patch/platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.platform.patch_torch_accelerator # noqa

if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
Loading
Loading