-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[main] add pd transfer for ascend scheduler #2753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9138372
cbcc27c
5f721fc
914611a
919cc03
7095897
e76d630
6384861
e63792d
e22760b
47b7a97
6dbc172
5a24294
225f0bd
87161a0
b6952b8
c74d4c6
f3b3218
396c81f
cb13556
2bf91df
b5e0425
03949b1
30571f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| import torch | ||
| from pytest_mock import MockerFixture | ||
| from vllm.config import SchedulerConfig, VllmConfig | ||
|
|
||
| from tests.ut.base import PytestBase | ||
| from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor | ||
|
|
||
|
|
||
| class TestMinPLogitsProcessorInitFunc(PytestBase): | ||
|
|
||
| def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture): | ||
| device_cpu = torch.device("cpu") | ||
| device_npu = torch.device("npu") | ||
| is_pin_memory = False | ||
| mock_vllm_config = mocker.MagicMock(spec=VllmConfig) | ||
| mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig) | ||
| mock_scheduler_config.decode_max_num_seqs = 0 | ||
| mock_scheduler_config.max_num_seqs = 128 | ||
| mock_vllm_config.scheduler_config = mock_scheduler_config | ||
| # torch.zeros/torch.empty returns error on online ut machine, so mock it | ||
| mock_tensor = torch.zeros((256, ), | ||
| dtype=torch.float32, | ||
| pin_memory=False) | ||
| mocker.patch("torch.zeros", return_value=mock_tensor) | ||
| mock_empty_tensor = torch.empty((256, ), dtype=torch.float32) | ||
| mocker.patch("torch.empty", return_value=mock_empty_tensor) | ||
|
|
||
| processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu, | ||
| is_pin_memory) | ||
|
|
||
| assert processor_cpu.min_p is not None | ||
| assert processor_cpu.use_double_tensor is False | ||
| assert processor_cpu.min_p_cpu.shape[0] == 256 | ||
|
|
||
| processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu, | ||
| is_pin_memory) | ||
|
|
||
| assert processor_cpu.min_p is not None | ||
| assert processor_cpu.use_double_tensor is True | ||
| assert processor_cpu.min_p_cpu.shape[0] == 256 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -58,6 +58,15 @@ def __init__( | |
| self.scheduled_req_ids: set[str] = set() | ||
| self.running: list[Request] = [] | ||
|
|
||
| self.finished_prefill_reqs: deque[Request] = deque() | ||
| enable_pd_transfer = getattr(self.scheduler_config, | ||
| 'enable_pd_transfer', False) | ||
| decode_max_num_seqs = getattr(self.scheduler_config, | ||
| 'decode_max_num_seqs', 0) | ||
| self.phase = "" if not enable_pd_transfer else "prefill" | ||
| self.decode_max_num_running_reqs = max(self.max_num_running_reqs, | ||
| decode_max_num_seqs) | ||
|
|
||
| def schedule(self) -> SchedulerOutput: | ||
| if self.scheduler_config.chunked_prefill_enabled: | ||
| return super().schedule() | ||
|
|
@@ -85,9 +94,25 @@ def schedule(self) -> SchedulerOutput: | |
| # and put back at the head of the waiting queue later | ||
| skipped_waiting_requests: deque[Request] = deque() | ||
|
|
||
| if self.phase == "prefill": | ||
| remaining_running_reqs = [] | ||
| for request in self.running: | ||
| # move request has finished prefill to finished_prefill_reqs | ||
| if request.num_tokens > request.num_prompt_tokens: | ||
| self.finished_prefill_reqs.append(request) | ||
| else: | ||
| remaining_running_reqs.append(request) | ||
| self.running = remaining_running_reqs | ||
| # all request prefilled, change phase to decode | ||
| if not self.waiting and not self.running: | ||
| self.phase = "decode" | ||
|
Comment on lines
+107
to
+108
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The phase transition from 'prefill' to 'decode' is currently one-way. Once the scheduler enters the 'decode' phase, it never returns to 'prefill'. If new requests arrive while the system is in the 'decode' phase, they will be prefilled and then immediately start decoding, which might not be the most efficient approach for the Ascend hardware this feature is targeting, as it breaks the strict batching of prefill operations. To improve performance for dynamic workloads, consider adding logic to allow the scheduler to switch back to the 'prefill' phase. For instance, you could add a check at the beginning of the scheduling step:

    if self.phase == "decode" and not self.running and self.waiting:
        self.phase = "prefill"

This would ensure that if the decoding queue is empty and new requests are waiting, the scheduler can switch back to the more efficient batch prefill mode. |
||
|
|
||
| # Schedule prefill requests first. | ||
| while self.waiting and token_budget > 0: | ||
| if len(self.running) == self.max_num_running_reqs: | ||
| if len(self.running) == (self.decode_max_num_running_reqs | ||
| if self.phase == "decode" else | ||
| self.max_num_running_reqs): | ||
|
|
||
| break | ||
|
|
||
| request = self.waiting[0] | ||
|
|
@@ -247,6 +272,13 @@ def skip_cur_request(): | |
| if skipped_waiting_requests: | ||
| self.waiting.extendleft(skipped_waiting_requests) | ||
|
|
||
| if self.phase == "decode": | ||
| while len( | ||
| self.running | ||
| ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs: | ||
| request = self.finished_prefill_reqs.popleft() | ||
| self.running.append(request) | ||
|
|
||
| # If no prefill requests are scheduled, | ||
| # Schedule decode requests next. | ||
| if len(self.scheduled_req_ids) == 0: | ||
|
|
@@ -350,7 +382,9 @@ def skip_cur_request(): | |
| total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) | ||
| assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens | ||
| assert token_budget >= 0 | ||
| assert len(self.running) <= self.max_num_running_reqs | ||
| assert len( | ||
| self.running | ||
| ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs | ||
| assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( | ||
| scheduled_running_reqs) <= len(self.running) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import itertools | ||
| from collections.abc import Sequence | ||
| from typing import TYPE_CHECKING, Union | ||
|
|
||
| import torch | ||
| from vllm.logger import init_logger | ||
| from vllm.v1.sample import logits_processor | ||
| from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, | ||
| MinTokensLogitsProcessor) | ||
| from vllm.v1.sample.logits_processor.interface import LogitsProcessor | ||
| from vllm.v1.sample.logits_processor.state import LogitsProcessors | ||
|
|
||
| from vllm_ascend.sample.logits_processor.builtin import \ | ||
| AscendMinPLogitsProcessor | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.config import VllmConfig | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| # Error message when the user tries to initialize vLLM with a pooling model | ||
| # and custom logitsproces | ||
| STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" | ||
| " logits processors.") | ||
|
|
||
| BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ | ||
| MinTokensLogitsProcessor, | ||
| LogitBiasLogitsProcessor, | ||
| AscendMinPLogitsProcessor, | ||
| ] | ||
|
|
||
|
|
||
| def build_logitsprocs( | ||
| vllm_config: "VllmConfig", | ||
| device: torch.device, | ||
| is_pin_memory: bool, | ||
| is_pooling_model: bool, | ||
| custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), | ||
| ) -> LogitsProcessors: | ||
| if is_pooling_model: | ||
| if custom_logitsprocs: | ||
| raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) | ||
| logger.debug("Skipping logits processor loading because pooling models" | ||
| " do not support logits processors.") | ||
| return LogitsProcessors() | ||
| custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs( | ||
| custom_logitsprocs) | ||
| return LogitsProcessors( | ||
| ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( | ||
| BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| import torch | ||
| from vllm.config import VllmConfig | ||
| from vllm.v1.sample.logits_processor import MinPLogitsProcessor | ||
|
|
||
|
|
||
| class AscendMinPLogitsProcessor(MinPLogitsProcessor): | ||
|
|
||
| def __init__(self, vllm_config: "VllmConfig", device: torch.device, | ||
| is_pin_memory: bool): | ||
| super().__init__(vllm_config, device, is_pin_memory) | ||
|
|
||
| decode_max_num_seqs = getattr(vllm_config.scheduler_config, | ||
| 'decode_max_num_seqs', 0) | ||
| if decode_max_num_seqs != 0: | ||
| max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs, | ||
| decode_max_num_seqs) | ||
|
|
||
| self.min_p_count: int = 0 | ||
|
|
||
| self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), | ||
| dtype=torch.float32, | ||
| device="cpu", | ||
| pin_memory=is_pin_memory) | ||
| self.min_p_cpu = self.min_p_cpu_tensor.numpy() | ||
|
|
||
| self.use_double_tensor = torch.device(device).type != "cpu" | ||
|
|
||
| if self.use_double_tensor: | ||
| # Pre-allocated device tensor | ||
| self.min_p_device: torch.Tensor = torch.empty( | ||
| (max_num_reqs, ), dtype=torch.float32, device=device) | ||
| else: | ||
| self.min_p_device = self.min_p_cpu_tensor | ||
| # Current slice of the device tensor | ||
| self.min_p: torch.Tensor = self.min_p_device[:0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Manually setting the internal state `scheduler.phase` makes this test brittle and less representative of real usage. If the initialization logic in `AscendScheduler.__init__` changes, this test would not catch the regression. A better approach is to initialize the scheduler with `enable_pd_transfer=True` in its configuration, which would correctly set the initial phase.

To achieve this, you could modify the `create_scheduler` helper method to accept configuration overrides. Then, the test can be updated to use that override instead of assigning `scheduler.phase` directly.
This change would make the test more robust and also serve to verify the scheduler's initialization logic.