Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When using it, decode is started only when prefill of all requests is done. This option only takes effects on offline inference. |
| `decode_max_num_seqs` | int | `0` | Whether to change max_num_seqs of decode phase when enable pd transfer. This option only takes effects when enable_pd_transfer is True. |

ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.

Expand Down
13 changes: 13 additions & 0 deletions tests/ut/core/test_schedule_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,16 @@ def test_invalid_config_without_chunked_prefill(self):
)
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
self.assertIn("max_model_len (4096)", str(context.exception))

def test_initialize_from_config_with_pd_transfer(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_pd_transfer=True,
decode_max_num_seqs=48,
max_num_batched_tokens=4096,
max_model_len=4096,
),
)
self.assertEqual(ascend_config.enable_pd_transfer, True)
self.assertEqual(ascend_config.decode_max_num_seqs, 48)
31 changes: 31 additions & 0 deletions tests/ut/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,34 @@ def test_memory_leak(self):

# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)

def test_scheduler_with_pd_transfer(self):
scheduler = self.create_scheduler()
scheduler.phase = "prefill"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Manually setting the internal state scheduler.phase makes this test brittle and less representative of real usage. If the initialization logic in AscendScheduler.__init__ changes, this test would not catch the regression. A better approach is to initialize the scheduler with enable_pd_transfer=True in its configuration, which would correctly set the initial phase.

To achieve this, you could modify the create_scheduler helper method to accept configuration overrides. For example:

def create_scheduler(self, mock_compute_encoder_budget, scheduler_config_override: Optional[Dict[str, Any]] = None):
    # ... existing setup ...
    scheduler_config = SchedulerConfig(
        # ...
    )
    if scheduler_config_override:
        for key, value in scheduler_config_override.items():
            setattr(scheduler_config, key, value)
    # ... rest of the function ...

Then, the test can be updated to:

scheduler = self.create_scheduler(scheduler_config_override={"enable_pd_transfer": True})
self.assertEqual(scheduler.phase, "prefill")

This change would make the test more robust and also serve to verify the scheduler's initialization logic.

requests = create_requests(num_requests=32)
for request in requests:
scheduler.add_request(request)

# 1st iteration, move 16 requests from waiting to running for prefill
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
first_iter_prefilled_req_num = len(scheduler.running)
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
scheduler.max_num_running_reqs)
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(scheduler_output.finished_req_ids), 0)

# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
# and move 16 requests from waiting to running for prefill
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(len(scheduler.finished_prefill_reqs),
first_iter_prefilled_req_num)

# 3rd iteration, all requests prefilled, change scheduler phase to decode
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(scheduler.phase, "decode")
40 changes: 40 additions & 0 deletions tests/ut/sample/logits_processor/test_builtin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from pytest_mock import MockerFixture
from vllm.config import SchedulerConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor


class TestMinPLogitsProcessorInitFunc(PytestBase):

def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
device_cpu = torch.device("cpu")
device_npu = torch.device("npu")
is_pin_memory = False
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
mock_scheduler_config.decode_max_num_seqs = 0
mock_scheduler_config.max_num_seqs = 128
mock_vllm_config.scheduler_config = mock_scheduler_config
# torch.zeros/torch.empty returns error on online ut machine, so mock it
mock_tensor = torch.zeros((256, ),
dtype=torch.float32,
pin_memory=False)
mocker.patch("torch.zeros", return_value=mock_tensor)
mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
mocker.patch("torch.empty", return_value=mock_empty_tensor)

processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
is_pin_memory)

assert processor_cpu.min_p is not None
assert processor_cpu.use_double_tensor is False
assert processor_cpu.min_p_cpu.shape[0] == 256

processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
is_pin_memory)

assert processor_cpu.min_p is not None
assert processor_cpu.use_double_tensor is True
assert processor_cpu.min_p_cpu.shape[0] == 256
4 changes: 4 additions & 0 deletions vllm_ascend/core/schedule_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")
enable_pd_transfer: bool = False
decode_max_num_seqs: int = 0

@classmethod
def initialize_from_config(
Expand All @@ -45,6 +47,8 @@ def initialize_from_config(
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
scheduler_config["enable_pd_transfer"] = False
scheduler_config["decode_max_num_seqs"] = 0
# Override params in original SchedulerConfig with params in ascend_scheduler_config
for k, _ in scheduler_config.items():
if hasattr(ascend_scheduler_config, k):
Expand Down
38 changes: 36 additions & 2 deletions vllm_ascend/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def __init__(
self.scheduled_req_ids: set[str] = set()
self.running: list[Request] = []

self.finished_prefill_reqs: deque[Request] = deque()
enable_pd_transfer = getattr(self.scheduler_config,
'enable_pd_transfer', False)
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.phase = "" if not enable_pd_transfer else "prefill"
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
decode_max_num_seqs)

def schedule(self) -> SchedulerOutput:
if self.scheduler_config.chunked_prefill_enabled:
return super().schedule()
Expand Down Expand Up @@ -85,9 +94,25 @@ def schedule(self) -> SchedulerOutput:
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()

if self.phase == "prefill":
remaining_running_reqs = []
for request in self.running:
# move request has finished prefill to finished_prefill_reqs
if request.num_tokens > request.num_prompt_tokens:
self.finished_prefill_reqs.append(request)
else:
remaining_running_reqs.append(request)
self.running = remaining_running_reqs
# all request prefilled, change phase to decode
if not self.waiting and not self.running:
self.phase = "decode"
Comment on lines +107 to +108
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The phase transition from 'prefill' to 'decode' is currently one-way. Once the scheduler enters the 'decode' phase, it never returns to 'prefill'. If new requests arrive while the system is in the 'decode' phase, they will be prefilled and then immediately start decoding, which might not be the most efficient approach for the Ascend hardware this feature is targeting, as it breaks the strict batching of prefill operations.

To improve performance for dynamic workloads, consider adding logic to allow the scheduler to switch back to the 'prefill' phase. For instance, you could add a check at the beginning of the schedule method:

if self.phase == "decode" and not self.running and self.waiting:
    self.phase = "prefill"

This would ensure that if the decoding queue is empty and new requests are waiting, the scheduler can switch back to the more efficient batch prefill mode.


# Schedule prefill requests first.
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
if len(self.running) == (self.decode_max_num_running_reqs
if self.phase == "decode" else
self.max_num_running_reqs):

break

request = self.waiting[0]
Expand Down Expand Up @@ -247,6 +272,13 @@ def skip_cur_request():
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)

if self.phase == "decode":
while len(
self.running
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
request = self.finished_prefill_reqs.popleft()
self.running.append(request)

# If no prefill requests are scheduled,
# Schedule decode requests next.
if len(self.scheduled_req_ids) == 0:
Expand Down Expand Up @@ -350,7 +382,9 @@ def skip_cur_request():
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert len(
self.running
) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs) <= len(self.running)

Expand Down
50 changes: 50 additions & 0 deletions vllm_ascend/sample/logits_processor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import itertools
from collections.abc import Sequence
from typing import TYPE_CHECKING, Union

import torch
from vllm.logger import init_logger
from vllm.v1.sample import logits_processor
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
MinTokensLogitsProcessor)
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.logits_processor.state import LogitsProcessors

from vllm_ascend.sample.logits_processor.builtin import \
AscendMinPLogitsProcessor

if TYPE_CHECKING:
from vllm.config import VllmConfig

logger = init_logger(__name__)

# Error message when the user tries to initialize vLLM with a pooling model
# and custom logitsproces
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
" logits processors.")

BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
AscendMinPLogitsProcessor,
]


def build_logitsprocs(
vllm_config: "VllmConfig",
device: torch.device,
is_pin_memory: bool,
is_pooling_model: bool,
custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
if is_pooling_model:
if custom_logitsprocs:
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
logger.debug("Skipping logits processor loading because pooling models"
" do not support logits processors.")
return LogitsProcessors()
custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
custom_logitsprocs)
return LogitsProcessors(
ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
35 changes: 35 additions & 0 deletions vllm_ascend/sample/logits_processor/builtin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import MinPLogitsProcessor


class AscendMinPLogitsProcessor(MinPLogitsProcessor):

def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
super().__init__(vllm_config, device, is_pin_memory)

decode_max_num_seqs = getattr(vllm_config.scheduler_config,
'decode_max_num_seqs', 0)
if decode_max_num_seqs != 0:
max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
decode_max_num_seqs)

self.min_p_count: int = 0

self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=is_pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()

self.use_double_tensor = torch.device(device).type != "cpu"

if self.use_double_tensor:
# Pre-allocated device tensor
self.min_p_device: torch.Tensor = torch.empty(
(max_num_reqs, ), dtype=torch.float32, device=device)
else:
self.min_p_device = self.min_p_cpu_tensor
# Current slice of the device tensor
self.min_p: torch.Tensor = self.min_p_device[:0]
7 changes: 5 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
Expand All @@ -86,6 +85,7 @@
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
Expand Down Expand Up @@ -178,7 +178,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
decode_max_num_seqs)
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device
Expand Down
Loading