Merged
10 changes: 10 additions & 0 deletions tests/ut/test_platform.py
@@ -363,6 +363,16 @@ def test_check_and_update_config_v1_worker_class_selection(
"vllm_ascend.worker.worker_v1.NPUWorker",
)

test_ascend_config = self.mock_ascend_config
test_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
self.mock_vllm_config.parallel_config.worker_cls = "auto"
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertEqual(
self.mock_vllm_config.parallel_config.worker_cls,
"vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
)

@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch("vllm_ascend.utils.is_310p", return_value=True)
5 changes: 4 additions & 1 deletion vllm_ascend/platform.py
@@ -163,7 +163,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
update_aclgraph_sizes(vllm_config)

if parallel_config and parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"

if cache_config:
if cache_config.block_size is None:
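With this change, check_and_update_config() resolves worker_cls to the torchair worker whenever torchair graph mode is enabled and falls back to the plain NPUWorker otherwise. A minimal usage sketch of the new path follows; the additional_config keys mirror the config names in this diff, but treating them as the user-facing knob (and passing additional_config directly to LLM) is an assumption, and the model name is only a placeholder.

from vllm import LLM

# Hypothetical usage sketch: enable torchair graph mode so that
# check_and_update_config() selects
# vllm_ascend.torchair.torchair_worker.NPUTorchairWorker as worker_cls.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={"torchair_graph_config": {"enabled": True}},
)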
Empty file.
64 changes: 64 additions & 0 deletions vllm_ascend/torchair/torchair_worker.py
@@ -0,0 +1,64 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.logger import logger

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
check_torchair_cache_exist,
delete_torchair_cache_file,
read_kv_cache_bytes_from_file)
from vllm_ascend.worker.worker_v1 import NPUWorker


class NPUTorchairWorker(NPUWorker):
"""Torchair worker bases on NPUWorker. Only torchair specified code should be added in this class."""

def determine_available_memory(self) -> int:
"""Override determine_available_memory to use cached torchair kv_cache_bytes."""

available_kv_cache_memory = super().determine_available_memory()

if check_torchair_cache_exist() and check_kv_cache_bytes_cache_exist():
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory

return available_kv_cache_memory

def _get_max_num_tokens_and_with_prefill(self):
"""Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""

        max_num_tokens, with_prefill = (
            super()._get_max_num_tokens_and_with_prefill())
if not with_prefill:
max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
max_num_tokens)
return max_num_tokens, with_prefill
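When no usable cache exists, the fallback branch above subtracts a floating tolerance, expressed in megabytes, from the freshly profiled KV-cache budget before handing it to the model runner as new_kv_cache_bytes. A small illustration of that arithmetic (the 64 MB tolerance is an assumed example; the real value comes from the VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE environment variable read in vllm_ascend.envs):

# Illustrative arithmetic only; both input values below are assumptions.
tolerance_mb = 64                                       # assumed tolerance in MB
bytes_floating_tolerance = 1024 * 1024 * tolerance_mb   # 67,108,864 bytes
available_kv_cache_memory = 8 * 1024**3                 # e.g. 8 GiB from profiling
available_kv_cache_memory -= bytes_floating_tolerance   # 8,522,825,728 bytes
# This reduced figure is what gets assigned to model_runner.new_kv_cache_bytes.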
51 changes: 13 additions & 38 deletions vllm_ascend/worker/worker_v1.py
@@ -38,15 +38,10 @@
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
check_torchair_cache_exist,
delete_torchair_cache_file,
read_kv_cache_bytes_from_file,
sleep_mode_enabled, try_register_lib)
from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


@@ -180,27 +175,6 @@ def determine_available_memory(self) -> int:
logger.info(
f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
)
if get_ascend_config().torchair_graph_config.enabled:
if check_torchair_cache_exist(
) and check_kv_cache_bytes_cache_exist():
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory

return available_kv_cache_memory

def execute_model(
@@ -291,19 +265,20 @@ def list_loras(self) -> set[int]:
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)

def execute_dummy_batch(self) -> None:
runner = self.model_runner
def _get_max_num_tokens_and_with_prefill(self):
max_num_tokens = 1
with_prefill = False
if runner.dp_size > 1:
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
if self.model_runner.dp_size > 1:
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
if runner.torchair_graph_enabled and not with_prefill:
max_num_tokens = runner.select_torchair_padded_batch_size(
max_num_tokens)
runner._dummy_run(max_num_tokens,
is_compile=False,
with_prefill=with_prefill)
return max_num_tokens, with_prefill

def execute_dummy_batch(self) -> None:
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
)
self.model_runner._dummy_run(max_num_tokens,
is_compile=False,
with_prefill=with_prefill)

def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""