From 91383723fc56dd3a977b7205fe9c00ac9fb9067b Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Sat, 30 Aug 2025 15:24:26 +0800
Subject: [PATCH 01/22] add pd pipelining

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/core/test_schedule_config.py |  9 ++++++++
 vllm_ascend/core/schedule_config.py   |  2 ++
 vllm_ascend/core/scheduler.py         | 31 +++++++++++++++++++++++++--
 vllm_ascend/worker/model_runner_v1.py |  2 +-
 4 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py
index df36b52523a..4c3b4f3fa94 100644
--- a/tests/ut/core/test_schedule_config.py
+++ b/tests/ut/core/test_schedule_config.py
@@ -165,3 +165,12 @@ def test_invalid_config_without_chunked_prefill(self):
         )
         self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
         self.assertIn("max_model_len (4096)", str(context.exception))
+
+    def test_initialize_from_config_with_decode_bs(self):
+        ascend_config = AscendSchedulerConfig.initialize_from_config(
+            self.basic_scheduler_config,
+            AscendSchedulerConfig(
+                decode_batch_size=128,
+            ),
+        )
+        self.assertEqual(ascend_config.decode_batch_size, 128)
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 4ee02e7ed40..130bc7885d9 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -28,6 +28,7 @@ class AscendSchedulerConfig(SchedulerConfig):
     num_scheduler_steps: int = 1
     scheduler_cls: Union[str, Type[object]] = (
         "vllm_ascend.core.scheduler.AscendScheduler")
+    decode_batch_size = 0
 
     @classmethod
     def initialize_from_config(
@@ -45,6 +46,7 @@ def initialize_from_config(
             scheduler_config["num_scheduler_steps"] = 1
             scheduler_config["scheduler_cls"] = (
                 "vllm_ascend.core.scheduler.AscendScheduler")
+            scheduler_config["decode_batch_size"] = 0
         # Override params in original SchedulerConfig with params in ascend_scheduler_config
         for k, _ in scheduler_config.items():
             if hasattr(ascend_scheduler_config, k):
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index e4fef281b74..8ac304447b2 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -58,6 +58,10 @@ def __init__(
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
 
+        self.finished_prefill_reqs: list[Request] = []
+        self.phase = "prefill"
+        self.max_num_decode_running_reqs = self.max_num_running_reqs if vllm_config.scheduler_config.decode_batch_size == 0 else vllm_config.scheduler_config.decode_batch_size
+
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -85,9 +89,24 @@ def schedule(self) -> SchedulerOutput:
         # and put back at the head of the waiting queue later
         skipped_waiting_requests: deque[Request] = deque()
 
+        # if max_num_decode_running_reqs configed, pop request that finished prefill to finished_prefill_reqs
+        if self.max_num_decode_running_reqs != self.max_num_running_reqs:
+            req_idx = 0
+            while self.phase == "prefill" and req_idx < len(len.running):
+                request = self.running[req_idx]
+                if not request in self.finished_prefill_reqs:
+                    if request.num_tokens - request.num_prompt_tokens >= 1:
+                        self.finished_prefill_reqs.append(request)
+                        self.running.remove(request)
+                        continue
+                    else:
+                        self.running.remove(request)
+                        continue
+                req_idx += 1
+
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == self.max_num_running_reqs:
+            if len(self.running) == (self.max_num_running_reqs if self.phase == "prefill" else self.max_num_decode_running_reqs):
                 break
 
             request = self.waiting[0]
@@ -245,6 +264,14 @@ def skip_cur_request():
         if skipped_waiting_requests:
             self.waiting.extendleft(skipped_waiting_requests)
 
+        if self.max_num_decode_running_reqs != self.max_num_running_reqs and self.phase == "prefill" and not self.waiting and not self.running:
+            logger.info("change scheduler phase to pure decode")
+            self.phase = "decode"
+        if self.phase == "decode":
+            while self.finished_prefill_reqs:
+                request = self.finished_prefill_reqs.pop(0)
+                self.waiting.append(request)
+
         # If no prefill requests are scheduled,
         # Schedule decode requests next.
         if len(self.scheduled_req_ids) == 0:
@@ -348,7 +375,7 @@ def skip_cur_request():
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
-        assert len(self.running) <= self.max_num_running_reqs
+        assert len(self.running) <= self.max_num_running_reqs if self.phase == "prefill" else self.max_num_decode_running_reqs
         assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
             scheduled_running_reqs) <= len(self.running)
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7a9fe1b8dce..da53835d0f8 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -175,7 +175,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            self.block_size)
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_reqs = self.scheduler_config.max_num_seqs if self.scheduler_config.decode_batch_size == 0 else self.scheduler_config.decode_batch_size
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device

From cbcc27c0ce8794be8c3b4a67c3d48d74c377fbaf Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Sat, 30 Aug 2025 16:18:06 +0800
Subject: [PATCH 02/22] reformat

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/core/test_schedule_config.py |  8 +++--
 tests/ut/core/test_scheduler.py       | 29 ++++++++++++++++
 vllm_ascend/core/schedule_config.py   |  4 +--
 vllm_ascend/core/scheduler.py         | 49 +++++++++++++--------------
 vllm_ascend/worker/model_runner_v1.py |  2 +-
 5 files changed, 61 insertions(+), 31 deletions(-)

diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py
index 4c3b4f3fa94..fcd790b691d 100644
--- a/tests/ut/core/test_schedule_config.py
+++ b/tests/ut/core/test_schedule_config.py
@@ -166,11 +166,13 @@ def test_invalid_config_without_chunked_prefill(self):
         self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
         self.assertIn("max_model_len (4096)", str(context.exception))
 
-    def test_initialize_from_config_with_decode_bs(self):
+    def test_initialize_from_config_with_pd_transfer(self):
         ascend_config = AscendSchedulerConfig.initialize_from_config(
             self.basic_scheduler_config,
             AscendSchedulerConfig(
-                decode_batch_size=128,
+                enable_pd_transfer=True,
+                max_num_batched_tokens=4096,
+                max_model_len=4096,
             ),
         )
-        self.assertEqual(ascend_config.decode_batch_size, 128)
+        self.assertEqual(ascend_config.enable_pd_transfer, True)
diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 1855c805bd8..1fe39e1096b 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -896,3 +896,32 @@ def test_memory_leak(self):
 
         # Confirm no memory leak.
         self.assert_scheduler_empty(scheduler)
+
+    def test_scheduler_with_pd_transfer(self):
+        scheduler = self.create_scheduler()
+        scheduler.phase = "prefill"
+        requests = create_requests(num_requests=32)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # 1st iteration, move 16 requests from waiting to running for prefill
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        first_iter_prefilled_req_num = len(scheduler.running)
+        self.assertEqual(len(scheduler_output.scheduled_new_reqs), scheduler.max_num_running_reqs)
+        self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(scheduler_output.finished_req_ids), 0)
+
+        # 2nd iteration, move 16 prefilled requests to finished_prefill_reqs 
+        # and move 16 requests from waiting to running for prefill
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        self.assertEqual(len(scheduler.finished_prefill_reqs), first_iter_prefilled_req_num)
+
+        # 3rd iteration, all requests prefilled, change scheduler phase to decode
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        self.assertEqual(scheduler.phase, "decode")
\ No newline at end of file
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 130bc7885d9..bdeb9d4e7c2 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -28,7 +28,7 @@ class AscendSchedulerConfig(SchedulerConfig):
     num_scheduler_steps: int = 1
     scheduler_cls: Union[str, Type[object]] = (
         "vllm_ascend.core.scheduler.AscendScheduler")
-    decode_batch_size = 0
+    enable_pd_transfer: bool = False
 
     @classmethod
     def initialize_from_config(
@@ -46,7 +46,7 @@ def initialize_from_config(
             scheduler_config["num_scheduler_steps"] = 1
             scheduler_config["scheduler_cls"] = (
                 "vllm_ascend.core.scheduler.AscendScheduler")
-            scheduler_config["decode_batch_size"] = 0
+            scheduler_config["enable_pd_transfer"] = False
         # Override params in original SchedulerConfig with params in ascend_scheduler_config
         for k, _ in scheduler_config.items():
             if hasattr(ascend_scheduler_config, k):
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 8ac304447b2..cb16c8e5b0f 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -58,9 +58,11 @@ def __init__(
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
 
-        self.finished_prefill_reqs: list[Request] = []
-        self.phase = "prefill"
-        self.max_num_decode_running_reqs = self.max_num_running_reqs if vllm_config.scheduler_config.decode_batch_size == 0 else vllm_config.scheduler_config.decode_batch_size
+        self.finished_prefill_reqs: deque[Request] = deque()
+        enable_pd_transfer = getattr(self.scheduler_config,
+                                     'enable_pd_transfer',
+                                     False)
+        self.phase = "" if not enable_pd_transfer else "prefill"
 
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -89,24 +91,22 @@ def schedule(self) -> SchedulerOutput:
         # and put back at the head of the waiting queue later
         skipped_waiting_requests: deque[Request] = deque()
 
-        # if max_num_decode_running_reqs configed, pop request that finished prefill to finished_prefill_reqs
-        if self.max_num_decode_running_reqs != self.max_num_running_reqs:
-            req_idx = 0
-            while self.phase == "prefill" and req_idx < len(len.running):
-                request = self.running[req_idx]
-                if not request in self.finished_prefill_reqs:
-                    if request.num_tokens - request.num_prompt_tokens >= 1:
-                        self.finished_prefill_reqs.append(request)
-                        self.running.remove(request)
-                        continue
-                    else:
-                        self.running.remove(request)
-                        continue
-                req_idx += 1
+        if self.phase == "prefill":
+            remaining_running_reqs = []
+            for request in self.running:
+                # move requests that have finished prefill to finished_prefill_reqs
+                if request.num_tokens > request.num_prompt_tokens:
+                    self.finished_prefill_reqs.append(request)
+                else:
+                    remaining_running_reqs.append(request)
+            self.running = remaining_running_reqs
+            # all requests prefilled, change phase to decode
+            if not self.waiting and not self.running:
+                self.phase = "decode"
 
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == (self.max_num_running_reqs if self.phase == "prefill" else self.max_num_decode_running_reqs):
+            if len(self.running) == self.max_num_running_reqs:
                 break
 
             request = self.waiting[0]
@@ -264,14 +264,13 @@ def skip_cur_request():
         if skipped_waiting_requests:
             self.waiting.extendleft(skipped_waiting_requests)
 
-        if self.max_num_decode_running_reqs != self.max_num_running_reqs and self.phase == "prefill" and not self.waiting and not self.running:
-            logger.info("change scheduler phase to pure decode")
-            self.phase = "decode"
         if self.phase == "decode":
-            while self.finished_prefill_reqs:
-                request = self.finished_prefill_reqs.pop(0)
-                self.waiting.append(request)
-
+            while len(
+                    self.running
+            ) < self.max_num_running_reqs and self.finished_prefill_reqs:
+                request = self.finished_prefill_reqs.popleft()
+                self.running.append(request)
+
         # If no prefill requests are scheduled,
         # Schedule decode requests next.
         if len(self.scheduled_req_ids) == 0:
@@ -375,7 +374,7 @@ def skip_cur_request():
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
-        assert len(self.running) <= self.max_num_running_reqs if self.phase == "prefill" else self.max_num_decode_running_reqs
+        assert len(self.running) <= self.max_num_running_reqs
         assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
             scheduled_running_reqs) <= len(self.running)
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index da53835d0f8..7a9fe1b8dce 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -175,7 +175,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            self.block_size)
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs if self.scheduler_config.decode_batch_size == 0 else self.scheduler_config.decode_batch_size
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device

From 5f721fc9cbe42661ec022ad3bbf4540f445ab657 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Thu, 4 Sep 2025 19:11:34 +0800
Subject: [PATCH 03/22] fix codecheck

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/core/test_scheduler.py | 10 ++++++----
 vllm_ascend/core/scheduler.py   |  3 +--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 1fe39e1096b..41e866c828e 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -909,19 +909,21 @@ def test_scheduler_with_pd_transfer(self):
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
         first_iter_prefilled_req_num = len(scheduler.running)
-        self.assertEqual(len(scheduler_output.scheduled_new_reqs), scheduler.max_num_running_reqs)
+        self.assertEqual(len(scheduler_output.scheduled_new_reqs),
+                         scheduler.max_num_running_reqs)
         self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
         self.assertEqual(len(scheduler_output.finished_req_ids), 0)
 
-        # 2nd iteration, move 16 prefilled requests to finished_prefill_reqs 
+        # 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
         # and move 16 requests from waiting to running for prefill
         scheduler_output = scheduler.schedule()
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
-        self.assertEqual(len(scheduler.finished_prefill_reqs), first_iter_prefilled_req_num)
+        self.assertEqual(len(scheduler.finished_prefill_reqs),
+                         first_iter_prefilled_req_num)
 
         # 3rd iteration, all requests prefilled, change scheduler phase to decode
         scheduler_output = scheduler.schedule()
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
-        self.assertEqual(scheduler.phase, "decode")
\ No newline at end of file
+        self.assertEqual(scheduler.phase, "decode")
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index cb16c8e5b0f..d453d5aebbe 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -60,8 +60,7 @@ def __init__(
 
         self.finished_prefill_reqs: deque[Request] = deque()
         enable_pd_transfer = getattr(self.scheduler_config,
-                                     'enable_pd_transfer',
-                                     False)
+                                     'enable_pd_transfer', False)
         self.phase = "" if not enable_pd_transfer else "prefill"
 
     def schedule(self) -> SchedulerOutput:

From 919cc03205481e3ae2a720a86774c3313138307d Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Thu, 4 Sep 2025 20:35:49 +0800
Subject: [PATCH 04/22] fix codecheck

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/core/test_scheduler.py | 4 ++--
 vllm_ascend/core/scheduler.py   | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 41e866c828e..4fafba60d50 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -909,7 +909,7 @@ def test_scheduler_with_pd_transfer(self):
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
         first_iter_prefilled_req_num = len(scheduler.running)
-        self.assertEqual(len(scheduler_output.scheduled_new_reqs), 
+        self.assertEqual(len(scheduler_output.scheduled_new_reqs),
                          scheduler.max_num_running_reqs)
         self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
         self.assertEqual(len(scheduler_output.finished_req_ids), 0)
@@ -919,7 +919,7 @@ def test_scheduler_with_pd_transfer(self):
         scheduler_output = scheduler.schedule()
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
-        self.assertEqual(len(scheduler.finished_prefill_reqs), 
+        self.assertEqual(len(scheduler.finished_prefill_reqs),
                          first_iter_prefilled_req_num)
 
         # 3rd iteration, all requests prefilled, change scheduler phase to decode
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 77710ea4aaf..7e5d00123d1 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -59,7 +59,7 @@ def __init__(
         self.running: list[Request] = []
 
         self.finished_prefill_reqs: deque[Request] = deque()
-        enable_pd_transfer = getattr(self.scheduler_config, 
+        enable_pd_transfer = getattr(self.scheduler_config,
                                      'enable_pd_transfer', False)
         self.phase = "" if not enable_pd_transfer else "prefill"
 

From 7095897e054bb17c6d559f0ff462e6c3905d913f Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 09:05:38 +0800
Subject: [PATCH 05/22] reformat

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/core/test_scheduler.py | 4 ++--
 vllm_ascend/core/scheduler.py   | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 4fafba60d50..b60e14b98fe 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -910,7 +910,7 @@ def test_scheduler_with_pd_transfer(self):
         scheduler.update_from_output(scheduler_output, model_runner_output)
         first_iter_prefilled_req_num = len(scheduler.running)
         self.assertEqual(len(scheduler_output.scheduled_new_reqs),
-                        scheduler.max_num_running_reqs)
+                         scheduler.max_num_running_reqs)
         self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
         self.assertEqual(len(scheduler_output.finished_req_ids), 0)
 
@@ -920,7 +920,7 @@ def test_scheduler_with_pd_transfer(self):
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
         self.assertEqual(len(scheduler.finished_prefill_reqs),
-                        first_iter_prefilled_req_num)
+                         first_iter_prefilled_req_num)
 
         # 3rd iteration, all requests prefilled, change scheduler phase to decode
         scheduler_output = scheduler.schedule()
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 7e5d00123d1..79fac1342ae 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -60,7 +60,7 @@ def __init__(
 
         self.finished_prefill_reqs: deque[Request] = deque()
         enable_pd_transfer = getattr(self.scheduler_config,
-                                    'enable_pd_transfer', False)
+                                     'enable_pd_transfer', False)
         self.phase = "" if not enable_pd_transfer else "prefill"
 
     def schedule(self) -> SchedulerOutput:

From e76d6308c17e5a4f76e4223fbb5d9c962785dc16 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 09:49:48 +0800
Subject: [PATCH 06/22] add decode max num seqs

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/core/schedule_config.py   |  2 ++
 vllm_ascend/core/scheduler.py         | 14 +++++++++++---
 vllm_ascend/worker/model_runner_v1.py |  4 +++-
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index bdeb9d4e7c2..422ca9aa3f7 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -29,6 +29,7 @@ class AscendSchedulerConfig(SchedulerConfig):
     scheduler_cls: Union[str, Type[object]] = (
         "vllm_ascend.core.scheduler.AscendScheduler")
     enable_pd_transfer: bool = False
+    decode_max_num_seqs: int = 0
 
     @classmethod
     def initialize_from_config(
@@ -48,6 +48,7 @@ def initialize_from_config(
             scheduler_config["scheduler_cls"] = (
                 "vllm_ascend.core.scheduler.AscendScheduler")
             scheduler_config["enable_pd_transfer"] = False
+            scheduler_config["decode_max_num_seqs"] = 0
         # Override params in original SchedulerConfig with params in ascend_scheduler_config
         for k, _ in scheduler_config.items():
             if hasattr(ascend_scheduler_config, k):
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 79fac1342ae..fb6fc62f7e0 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -61,7 +61,10 @@ def __init__(
         self.finished_prefill_reqs: deque[Request] = deque()
         enable_pd_transfer = getattr(self.scheduler_config,
                                      'enable_pd_transfer', False)
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
         self.phase = "" if not enable_pd_transfer else "prefill"
+        self.decode_max_num_running_reqs = self.max_num_running_reqs if decode_max_num_seqs == 0 else decode_max_num_seqs
 
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
@@ -105,7 +108,10 @@ def schedule(self) -> SchedulerOutput:
 
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == self.max_num_running_reqs:
+            if len(self.running) == (self.decode_max_num_seqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs):
+
                 break
 
             request = self.waiting[0]
@@ -268,7 +274,7 @@ def skip_cur_request():
         if self.phase == "decode":
             while len(
                     self.running
-            ) < self.max_num_running_reqs and self.finished_prefill_reqs:
+            ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
                 request = self.finished_prefill_reqs.popleft()
                 self.running.append(request)
 
@@ -375,7 +381,9 @@ def skip_cur_request():
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
-        assert len(self.running) <= self.max_num_running_reqs
+        assert len(
+            self.running
+        ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
         assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
             scheduled_running_reqs) <= len(self.running)
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 3068d36d0dc..f124842bc9d 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -177,7 +177,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            self.block_size)
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = self.scheduler_config.max_num_seqs if decode_max_num_seqs == 0 else decode_max_num_seqs
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device

From 6384861b64b027cc044a231f1ef5de8322748d0b Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 10:29:21 +0800
Subject: [PATCH 07/22] fix bs

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/core/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index fb6fc62f7e0..01492ab7981 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -108,7 +108,7 @@ def schedule(self) -> SchedulerOutput:
 
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == (self.decode_max_num_seqs
+            if len(self.running) == (self.decode_max_num_running_reqs
                                      if self.phase == "decode" else
                                      self.max_num_running_reqs):
 

From e63792d145e46318bce9dfd25d0cd4d40fb1a6d0 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 11:41:12 +0800
Subject: [PATCH 08/22] fix bs

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/core/scheduler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 01492ab7981..3b52226ad36 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -383,7 +383,8 @@ def skip_cur_request():
         assert token_budget >= 0
         assert len(
             self.running
-        ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs 
+        ) <= (self.decode_max_num_running_reqs
+              if self.phase == "decode" else self.max_num_running_reqs)
         assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
             scheduled_running_reqs) <= len(self.running)
 

From e22760b5606bb77b560b4d1dfaa8be90f0a7db50 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 13:08:46 +0800
Subject: [PATCH 09/22] fix bs

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/core/test_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index b60e14b98fe..330115cc45d 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -903,7 +903,7 @@ def test_scheduler_with_pd_transfer(self):
         requests = create_requests(num_requests=32)
         for request in requests:
             scheduler.add_request(request)
-        
+
         # 1st iteration, move 16 requests from waiting to running for prefill
         scheduler_output = scheduler.schedule()
         model_runner_output = make_output(scheduler)

From 47b7a97a841cb6f174e5ea18253247acf4b686f7 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 14:24:29 +0800
Subject: [PATCH 10/22] add MinPLogitsProcessor reinit

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/ops/__init__.py               |  1 +
 vllm_ascend/ops/min_p_logits_processor.py | 35 +++++++++++++++++++++++
 2 files changed, 36 insertions(+)
 create mode 100644 vllm_ascend/ops/min_p_logits_processor.py

diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py
index a1e7417b072..acf97fa1e46 100644
--- a/vllm_ascend/ops/__init__.py
+++ b/vllm_ascend/ops/__init__.py
@@ -20,6 +20,7 @@
 import vllm_ascend.ops.common_fused_moe  # noqa
 import vllm_ascend.ops.fused_moe  # noqa
 import vllm_ascend.ops.layernorm  # noqa
+import vllm_ascend.ops.min_p_logits_processor  # noqa
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
 from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
 from vllm_ascend.ops.rotary_embedding import (
diff --git a/vllm_ascend/ops/min_p_logits_processor.py b/vllm_ascend/ops/min_p_logits_processor.py
new file mode 100644
index 00000000000..54211cf0412
--- /dev/null
+++ b/vllm_ascend/ops/min_p_logits_processor.py
@@ -0,0 +1,35 @@
+import torch
+from vllm.config import get_current_vllm_config
+from vllm.v1.sample.logits_processor impot MinPLogitsProcessor
+
+original_min_p_logits_processor_init_func = MinPLogitsProcessor.__init__
+
+def min_p_logits_processor_init_func(self, *args, **kwargs):
+    original_min_p_logits_processor_init_func(self, *args, **kwargs)
+
+    vllml_config = get_current_vllm_config()
+    decode_max_num_seqs = getattr(vllml_config.scheduler_config,
+                                  'decode_max_num_seqs', 0)
+    # reinit MinPLogitsProcessor if decode_max_num_seqs configured
+    if decode_max_num_seqs != 0:
+        device = args[1]
+        is_pin_memory = args[2]
+        max_num_reqs = decode_max_num_seqs
+
+        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
+                                            dtype=torch.float32,
+                                            device="cpu",
+                                            pin_memory=is_pin_memory)
+        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+
+        self.use_double_tensor = torch.device(device).type != "cpu"
+
+        if self.use_double_tensor:
+            self.min_p_device = torch.empty((max_num_reqs, ),
+                                            dtype=torch.float32,
+                                            device=device)
+        else:
+            self.min_p_device = self.min_p_cpu_tensor
+        self.min_p = self.min_p_device[:0]
+
+MinPLogitsProcessor.__init__ = min_p_logits_processor_init_func

From 6dbc1728b422550bd9b81f8aa469fc3daa47e5cf Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 14:50:35 +0800
Subject: [PATCH 11/22] change decode bs

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/core/scheduler.py             | 2 +-
 vllm_ascend/ops/min_p_logits_processor.py | 8 ++++----
 vllm_ascend/worker/model_runner_v1.py     | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 3b52226ad36..85361c80e9b 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -64,7 +64,7 @@ def __init__(
         decode_max_num_seqs = getattr(self.scheduler_config,
                                       'decode_max_num_seqs', 0)
         self.phase = "" if not enable_pd_transfer else "prefill"
-        self.decode_max_num_running_reqs = self.max_num_running_reqs if decode_max_num_seqs == 0 else decode_max_num_seqs
+        self.decode_max_num_running_reqs = max(self.max_num_running_reqs, decode_max_num_seqs)
 
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
diff --git a/vllm_ascend/ops/min_p_logits_processor.py b/vllm_ascend/ops/min_p_logits_processor.py
index 54211cf0412..d12c1d92dea 100644
--- a/vllm_ascend/ops/min_p_logits_processor.py
+++ b/vllm_ascend/ops/min_p_logits_processor.py
@@ -1,20 +1,20 @@
 import torch
 from vllm.config import get_current_vllm_config
-from vllm.v1.sample.logits_processor impot MinPLogitsProcessor
+from vllm.v1.sample.logits_processor import MinPLogitsProcessor
 
 original_min_p_logits_processor_init_func = MinPLogitsProcessor.__init__
 
 def min_p_logits_processor_init_func(self, *args, **kwargs):
     original_min_p_logits_processor_init_func(self, *args, **kwargs)
 
-    vllml_config = get_current_vllm_config()
-    decode_max_num_seqs = getattr(vllml_config.scheduler_config,
+    vllm_config = get_current_vllm_config()
+    decode_max_num_seqs = getattr(vllm_config.scheduler_config,
                                   'decode_max_num_seqs', 0)
     # reinit MinPLogitsProcessor if decode_max_num_seqs configured
     if decode_max_num_seqs != 0:
         device = args[1]
         is_pin_memory = args[2]
-        max_num_reqs = decode_max_num_seqs
+        max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs, decode_max_num_seqs)
 
         self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                             dtype=torch.float32,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f124842bc9d..4084a06180c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -179,7 +179,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         decode_max_num_seqs = getattr(self.scheduler_config,
                                       'decode_max_num_seqs', 0)
-        self.max_num_reqs = self.scheduler_config.max_num_seqs if decode_max_num_seqs == 0 else decode_max_num_seqs
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs, decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device

From 5a24294973976eb795508e78cb7449ff5614b095 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 15:17:37 +0800
Subject: [PATCH 12/22] fix codecheck

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/core/scheduler.py             | 3 ++-
 vllm_ascend/ops/min_p_logits_processor.py | 3 ++-
 vllm_ascend/worker/model_runner_v1.py     | 3 ++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 85361c80e9b..965578155d5 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -64,7 +64,8 @@ def __init__(
         decode_max_num_seqs = getattr(self.scheduler_config,
                                       'decode_max_num_seqs', 0)
         self.phase = "" if not enable_pd_transfer else "prefill"
-        self.decode_max_num_running_reqs = max(self.max_num_running_reqs, decode_max_num_seqs)
+        self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
+                                               decode_max_num_seqs)
 
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
diff --git a/vllm_ascend/ops/min_p_logits_processor.py b/vllm_ascend/ops/min_p_logits_processor.py
index d12c1d92dea..2081e077c0d 100644
--- a/vllm_ascend/ops/min_p_logits_processor.py
+++ b/vllm_ascend/ops/min_p_logits_processor.py
@@ -14,7 +14,8 @@ def min_p_logits_processor_init_func(self, *args, **kwargs):
     if decode_max_num_seqs != 0:
         device = args[1]
         is_pin_memory = args[2]
-        max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs, decode_max_num_seqs)
+        max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
+                           decode_max_num_seqs)
 
         self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                             dtype=torch.float32,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4084a06180c..529ad13d717 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -179,7 +179,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         decode_max_num_seqs = getattr(self.scheduler_config,
                                       'decode_max_num_seqs', 0)
-        self.max_num_reqs = max(self.scheduler_config.max_num_seqs, decode_max_num_seqs)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device

From 225f0bdd31cca960a6804ffd37234018dd3ac3f5 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Fri, 5 Sep 2025 16:23:30 +0800
Subject: [PATCH 13/22] fix codecheck

Signed-off-by: CaranLic <740821011@qq.com>
---
 vllm_ascend/ops/min_p_logits_processor.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm_ascend/ops/min_p_logits_processor.py b/vllm_ascend/ops/min_p_logits_processor.py
index 2081e077c0d..3056aa50719 100644
--- a/vllm_ascend/ops/min_p_logits_processor.py
+++ b/vllm_ascend/ops/min_p_logits_processor.py
@@ -4,6 +4,7 @@
 
 original_min_p_logits_processor_init_func = MinPLogitsProcessor.__init__
 
+
 def min_p_logits_processor_init_func(self, *args, **kwargs):
     original_min_p_logits_processor_init_func(self, *args, **kwargs)
 
@@ -33,4 +34,5 @@ def min_p_logits_processor_init_func(self, *args, **kwargs):
             self.min_p_device = self.min_p_cpu_tensor
         self.min_p = self.min_p_device[:0]
 
+
 MinPLogitsProcessor.__init__ = min_p_logits_processor_init_func

From 87161a035b452f8f98e4f20dc794ad3296bcbb3f Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 09:47:32 +0800
Subject: [PATCH 14/22] add doc for new config, add test case for config

Signed-off-by: CaranLic <740821011@qq.com>
---
 docs/source/user_guide/configuration/additional_config.md | 2 ++
 tests/ut/core/test_schedule_config.py                     | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index c67f340815e..a8403acaa16 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -57,6 +57,8 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
+| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When enabled, decode starts only after prefill of all requests is done. This option only takes effect for offline inference. |
+| `decode_max_num_seqs` | int | `0` | The max_num_seqs used in the decode phase when pd transfer is enabled. This option only takes effect when enable_pd_transfer is True. |
 
 ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py
index fcd790b691d..8f1422f1b8c 100644
--- a/tests/ut/core/test_schedule_config.py
+++ b/tests/ut/core/test_schedule_config.py
@@ -171,8 +171,10 @@ def test_initialize_from_config_with_pd_transfer(self):
             self.basic_scheduler_config,
             AscendSchedulerConfig(
                 enable_pd_transfer=True,
+                decode_max_num_seqs=48,
                 max_num_batched_tokens=4096,
                 max_model_len=4096,
             ),
         )
         self.assertEqual(ascend_config.enable_pd_transfer, True)
+        self.assertEqual(ascend_config.decode_max_num_seqs, 48)

From b6952b88af5a6551466c4802ff6a8d2af17d9997 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 18:01:09 +0800
Subject: [PATCH 15/22] add test for patch processor

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 60 +++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 tests/ut/ops/test_min_p_logits_processor.py

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
new file mode 100644
index 00000000000..579a363348b
--- /dev/null
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -0,0 +1,60 @@
+from pytest_mock import MockerFixture
+from vllm.config import VllmConfig, SchedulerConfig
+from vllm.v1.sample.logits_processor import MinPLogitsProcessor
+
+from tests.ut.base import PytestBase
+from vllm_ascend.ops.min_p_logits_processor import min_p_logits_processor_init_func
+
+
+class TestMinPLogitsProcessorInitFunc(PytestBase):
+
+    def test_init_func_without_decode_max_num_seqs(self,
+                                                   mocker: MockerFixture):
+        mock_min_p_logits_processor = mocker.MagicMock(
+            spec=MinPLogitsProcessor)
+
+        min_p_logits_processor_init_func(mock_min_p_logits_processor,
+                                         VllmConfig(), "cpu:0", True)
+
+        assert mock_min_p_logits_processor.min_p_cpu is not None
+        assert mock_min_p_logits_processor.min_p_device is not None
+        assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 128
+
+    def test_init_func_with_decode_max_num_seqs_and_npu(
+            self, mocker: MockerFixture):
+        mock_min_p_logits_processor = mocker.MagicMock(
+            spec=MinPLogitsProcessor)
+
+        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
+        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.decode_max_num_seqs = 256
+        mock_scheduler_config.max_num_seqs = 128
+        mock_vllm_config.scheduler_config = mock_scheduler_config
+        mocker.patch(
+            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
+            return_value=mock_vllm_config)
+
+        min_p_logits_processor_init_func(mock_min_p_logits_processor,
+                                         mock_vllm_config, "npu:0", True)
+
+        assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
+        assert mock_min_p_logits_processor.use_double_tensor == True
+
+    def test_init_func_with_decode_max_num_seqs_and_cpu(
+            self, mocker: MockerFixture):
+        mock_min_p_logits_processor = mocker.MagicMock(
+            spec=MinPLogitsProcessor)
+
+        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
+        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.max_num_seqs = 128
+        mock_scheduler_config.decode_max_num_seqs = 256
+        mock_vllm_config.scheduler_config = mock_scheduler_config
+        mocker.patch(
+            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
+            return_value=mock_vllm_config)
+
+        min_p_logits_processor_init_func(mock_min_p_logits_processor,
+                                         mock_vllm_config, "cpu:0", True)
+
+        assert mock_min_p_logits_processor.use_double_tensor == False

From c74d4c6359484a41734dc3ed78a03de886a3cc92 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 18:48:49 +0800
Subject: [PATCH 16/22] reformat

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
index 579a363348b..75a861b8e4b 100644
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -1,9 +1,10 @@
 from pytest_mock import MockerFixture
-from vllm.config import VllmConfig, SchedulerConfig
+from vllm.config import SchedulerConfig, VllmConfig
 from vllm.v1.sample.logits_processor import MinPLogitsProcessor
 
 from tests.ut.base import PytestBase
-from vllm_ascend.ops.min_p_logits_processor import min_p_logits_processor_init_func
+from vllm_ascend.ops.min_p_logits_processor import \
+    min_p_logits_processor_init_func
 
 
 class TestMinPLogitsProcessorInitFunc(PytestBase):

From f3b3218ca49a5f58ae0f9450426a6e7b62a4355e Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 18:53:37 +0800
Subject: [PATCH 17/22] reformat

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
index 75a861b8e4b..cf161eb9c2e 100644
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -39,7 +39,7 @@ def test_init_func_with_decode_max_num_seqs_and_npu(
                                          mock_vllm_config, "npu:0", True)
 
         assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
-        assert mock_min_p_logits_processor.use_double_tensor == True
+        assert mock_min_p_logits_processor.use_double_tensor is True
 
     def test_init_func_with_decode_max_num_seqs_and_cpu(
             self, mocker: MockerFixture):
@@ -58,4 +58,4 @@ def test_init_func_with_decode_max_num_seqs_and_cpu(
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
                                          mock_vllm_config, "cpu:0", True)
 
-        assert mock_min_p_logits_processor.use_double_tensor == False
+        assert mock_min_p_logits_processor.use_double_tensor is False

From 396c81f5ccf28736bdb8301a140035f6e8501485 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 19:35:07 +0800
Subject: [PATCH 18/22] fix test error for unexpected online ut

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
index cf161eb9c2e..dff3b03e22a 100644
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -13,9 +13,17 @@ def test_init_func_without_decode_max_num_seqs(self,
                                                    mocker: MockerFixture):
         mock_min_p_logits_processor = mocker.MagicMock(
             spec=MinPLogitsProcessor)
+        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
+        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.decode_max_num_seqs = 0
+        mock_scheduler_config.max_num_seqs = 128
+        mock_vllm_config.scheduler_config = mock_scheduler_config
+        mocker.patch(
+            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
+            return_value=mock_vllm_config)
 
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         VllmConfig(), "cpu:0", True)
+                                         mock_vllm_config, "cpu:0", True)
 
         assert mock_min_p_logits_processor.min_p_cpu is not None
         assert mock_min_p_logits_processor.min_p_device is not None

From cb13556e3da578c12a0abbb4df396a2f8b5bd55e Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 20:47:52 +0800
Subject: [PATCH 19/22] fix test error for online ut get device count error

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
index dff3b03e22a..2993f900bf3 100644
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -13,6 +13,8 @@ def test_init_func_without_decode_max_num_seqs(self,
                                                    mocker: MockerFixture):
         mock_min_p_logits_processor = mocker.MagicMock(
             spec=MinPLogitsProcessor)
+        mock_min_p_logits_processor.min_p_cpu = None
+
         mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
         mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
         mock_scheduler_config.decode_max_num_seqs = 0
@@ -21,13 +23,14 @@ def test_init_func_without_decode_max_num_seqs(self,
         mocker.patch(
             "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
             return_value=mock_vllm_config)
+        mocker.patch(
+            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
+            return_value=None)
 
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "cpu:0", True)
+                                         mock_vllm_config, "cpu", True)
 
-        assert mock_min_p_logits_processor.min_p_cpu is not None
-        assert mock_min_p_logits_processor.min_p_device is not None
-        assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 128
+        assert mock_min_p_logits_processor.min_p_cpu is None
 
     def test_init_func_with_decode_max_num_seqs_and_npu(
             self, mocker: MockerFixture):
@@ -42,9 +45,12 @@ def test_init_func_with_decode_max_num_seqs_and_npu(
         mocker.patch(
             "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
             return_value=mock_vllm_config)
+        mocker.patch(
+            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
+            return_value=None)
 
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "npu:0", True)
+                                         mock_vllm_config, "npu", True)
 
         assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
         assert mock_min_p_logits_processor.use_double_tensor is True
@@ -62,8 +68,11 @@ def test_init_func_with_decode_max_num_seqs_and_cpu(
         mocker.patch(
             "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
             return_value=mock_vllm_config)
+        mocker.patch(
+            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
+            return_value=None)
 
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "cpu:0", True)
+                                         mock_vllm_config, "cpu", True)
 
         assert mock_min_p_logits_processor.use_double_tensor is False

From 2bf91dffe40379c9ff27123fc3b1abee372cd228 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 21:12:56 +0800
Subject: [PATCH 20/22] fix test error for online ut torch function error

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
index 2993f900bf3..0ff4f2a74ac 100644
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -48,9 +48,16 @@ def test_init_func_with_decode_max_num_seqs_and_npu(
         mocker.patch(
             "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
             return_value=None)
+        # torch.zeros/torch.empty returns error on online ut machine, so mock it
+        mock_tensor = torch.zeros((256, ),
+                                  dtype=torch.float32,
+                                  pin_memory=False)
+        mocker.patch("torch.zeros", return_value=mock_tensor)
+        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
+        mocker.patch("torch.empty", return_value=mock_empty_tensor)
 
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "npu", True)
+                                         mock_vllm_config, "npu", False)
 
         assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
         assert mock_min_p_logits_processor.use_double_tensor is True
@@ -71,8 +78,13 @@ def test_init_func_with_decode_max_num_seqs_and_cpu(
         mocker.patch(
             "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
             return_value=None)
+        # torch.zeros returns error on online ut machine, so mock it
+        mock_tensor = torch.zeros((256, ),
+                                  dtype=torch.float32,
+                                  pin_memory=False)
+        mocker.patch("torch.zeros", return_value=mock_tensor)
 
         min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "cpu", True)
+                                         mock_vllm_config, "cpu", False)
 
         assert mock_min_p_logits_processor.use_double_tensor is False

From b5e042545077142ab09e38a4822e486dae1f5e36 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Mon, 8 Sep 2025 21:18:17 +0800
Subject: [PATCH 21/22] fix test error for online ut torch function error

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
index 0ff4f2a74ac..7141d3e05b3 100644
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ b/tests/ut/ops/test_min_p_logits_processor.py
@@ -1,3 +1,4 @@
+import torch
 from pytest_mock import MockerFixture
 from vllm.config import SchedulerConfig, VllmConfig
 from vllm.v1.sample.logits_processor import MinPLogitsProcessor

From 30571f1a6fdd3e3ca84893eb6e0eaa169574d154 Mon Sep 17 00:00:00 2001
From: CaranLic <740821011@qq.com>
Date: Tue, 9 Sep 2025 17:17:16 +0800
Subject: [PATCH 22/22] change from patch MinPLogitsProcessor to redefine
 build_logitsprocs

Signed-off-by: CaranLic <740821011@qq.com>
---
 tests/ut/ops/test_min_p_logits_processor.py   | 91 -------------------
 .../sample/logits_processor/test_builtin.py   | 40 ++++++++
 vllm_ascend/ops/__init__.py                   |  1 -
 vllm_ascend/ops/min_p_logits_processor.py     | 38 --------
 .../sample/logits_processor/__init__.py       | 50 ++++++++++
 .../sample/logits_processor/builtin.py        | 35 +++++++
 vllm_ascend/worker/model_runner_v1.py         |  2 +-
 7 files changed, 126 insertions(+), 131 deletions(-)
 delete mode 100644 tests/ut/ops/test_min_p_logits_processor.py
 create mode 100644 tests/ut/sample/logits_processor/test_builtin.py
 delete mode 100644 vllm_ascend/ops/min_p_logits_processor.py
 create mode 100644 vllm_ascend/sample/logits_processor/__init__.py
 create mode 100644 vllm_ascend/sample/logits_processor/builtin.py

diff --git a/tests/ut/ops/test_min_p_logits_processor.py b/tests/ut/ops/test_min_p_logits_processor.py
deleted file mode 100644
index 7141d3e05b3..00000000000
--- a/tests/ut/ops/test_min_p_logits_processor.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import torch
-from pytest_mock import MockerFixture
-from vllm.config import SchedulerConfig, VllmConfig
-from vllm.v1.sample.logits_processor import MinPLogitsProcessor
-
-from tests.ut.base import PytestBase
-from vllm_ascend.ops.min_p_logits_processor import \
-    min_p_logits_processor_init_func
-
-
-class TestMinPLogitsProcessorInitFunc(PytestBase):
-
-    def test_init_func_without_decode_max_num_seqs(self,
-                                                   mocker: MockerFixture):
-        mock_min_p_logits_processor = mocker.MagicMock(
-            spec=MinPLogitsProcessor)
-        mock_min_p_logits_processor.min_p_cpu = None
-
-        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
-        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
-        mock_scheduler_config.decode_max_num_seqs = 0
-        mock_scheduler_config.max_num_seqs = 128
-        mock_vllm_config.scheduler_config = mock_scheduler_config
-        mocker.patch(
-            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
-            return_value=mock_vllm_config)
-        mocker.patch(
-            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
-            return_value=None)
-
-        min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "cpu", True)
-
-        assert mock_min_p_logits_processor.min_p_cpu is None
-
-    def test_init_func_with_decode_max_num_seqs_and_npu(
-            self, mocker: MockerFixture):
-        mock_min_p_logits_processor = mocker.MagicMock(
-            spec=MinPLogitsProcessor)
-
-        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
-        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
-        mock_scheduler_config.decode_max_num_seqs = 256
-        mock_scheduler_config.max_num_seqs = 128
-        mock_vllm_config.scheduler_config = mock_scheduler_config
-        mocker.patch(
-            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
-            return_value=mock_vllm_config)
-        mocker.patch(
-            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
-            return_value=None)
-        # torch.zeros/torch.empty returns error on online ut machine, so mock it
-        mock_tensor = torch.zeros((256, ),
-                                  dtype=torch.float32,
-                                  pin_memory=False)
-        mocker.patch("torch.zeros", return_value=mock_tensor)
-        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
-        mocker.patch("torch.empty", return_value=mock_empty_tensor)
-
-        min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "npu", False)
-
-        assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
-        assert mock_min_p_logits_processor.use_double_tensor is True
-
-    def test_init_func_with_decode_max_num_seqs_and_cpu(
-            self, mocker: MockerFixture):
-        mock_min_p_logits_processor = mocker.MagicMock(
-            spec=MinPLogitsProcessor)
-
-        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
-        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
-        mock_scheduler_config.max_num_seqs = 128
-        mock_scheduler_config.decode_max_num_seqs = 256
-        mock_vllm_config.scheduler_config = mock_scheduler_config
-        mocker.patch(
-            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
-            return_value=mock_vllm_config)
-        mocker.patch(
-            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
-            return_value=None)
-        # torch.zeros returns error on online ut machine, so mock it
-        mock_tensor = torch.zeros((256, ),
-                                  dtype=torch.float32,
-                                  pin_memory=False)
-        mocker.patch("torch.zeros", return_value=mock_tensor)
-
-        min_p_logits_processor_init_func(mock_min_p_logits_processor,
-                                         mock_vllm_config, "cpu", False)
-
-        assert mock_min_p_logits_processor.use_double_tensor is False
diff --git a/tests/ut/sample/logits_processor/test_builtin.py b/tests/ut/sample/logits_processor/test_builtin.py
new file mode 100644
index 00000000000..cecd18624d5
--- /dev/null
+++ b/tests/ut/sample/logits_processor/test_builtin.py
@@ -0,0 +1,40 @@
+import torch
+from pytest_mock import MockerFixture
+from vllm.config import SchedulerConfig, VllmConfig
+
+from tests.ut.base import PytestBase
+from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
+
+
+class TestMinPLogitsProcessorInitFunc(PytestBase):
+
+    def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
+        device_cpu = torch.device("cpu")
+        device_npu = torch.device("npu")
+        is_pin_memory = False
+        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
+        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.decode_max_num_seqs = 256
+        mock_scheduler_config.max_num_seqs = 128
+        mock_vllm_config.scheduler_config = mock_scheduler_config
+        # torch.zeros/torch.empty raise errors on the online ut machine, so mock them
+        mock_tensor = torch.zeros((256, ),
+                                  dtype=torch.float32,
+                                  pin_memory=False)
+        mocker.patch("torch.zeros", return_value=mock_tensor)
+        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
+        mocker.patch("torch.empty", return_value=mock_empty_tensor)
+
+        processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
+                                                  is_pin_memory)
+
+        assert processor_cpu.min_p is not None
+        assert processor_cpu.use_double_tensor is False
+        assert processor_cpu.min_p_cpu.shape[0] == 256
+
+        processor_npu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
+                                                  is_pin_memory)
+
+        assert processor_npu.min_p is not None
+        assert processor_npu.use_double_tensor is True
+        assert processor_npu.min_p_cpu.shape[0] == 256
diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py
index 5317add8cf6..5c8a79847c9 100644
--- a/vllm_ascend/ops/__init__.py
+++ b/vllm_ascend/ops/__init__.py
@@ -20,7 +20,6 @@
 import vllm_ascend.ops.common_fused_moe  # noqa
 import vllm_ascend.ops.fused_moe  # noqa
 import vllm_ascend.ops.layernorm  # noqa
-import vllm_ascend.ops.min_p_logits_processor  # noqa
 import vllm_ascend.ops.register_custom_ops  # noqa
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
 from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
diff --git a/vllm_ascend/ops/min_p_logits_processor.py b/vllm_ascend/ops/min_p_logits_processor.py
deleted file mode 100644
index 3056aa50719..00000000000
--- a/vllm_ascend/ops/min_p_logits_processor.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-from vllm.config import get_current_vllm_config
-from vllm.v1.sample.logits_processor import MinPLogitsProcessor
-
-original_min_p_logits_processor_init_func = MinPLogitsProcessor.__init__
-
-
-def min_p_logits_processor_init_func(self, *args, **kwargs):
-    original_min_p_logits_processor_init_func(self, *args, **kwargs)
-
-    vllm_config = get_current_vllm_config()
-    decode_max_num_seqs = getattr(vllm_config.scheduler_config,
-                                  'decode_max_num_seqs', 0)
-    # reinit MinPLogitsProcessor if decode_max_num_seqs configured
-    if decode_max_num_seqs != 0:
-        device = args[1]
-        is_pin_memory = args[2]
-        max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
-                           decode_max_num_seqs)
-
-        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
-                                            dtype=torch.float32,
-                                            device="cpu",
-                                            pin_memory=is_pin_memory)
-        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
-
-        self.use_double_tensor = torch.device(device).type != "cpu"
-
-        if self.use_double_tensor:
-            self.min_p_device = torch.empty((max_num_reqs, ),
-                                            dtype=torch.float32,
-                                            device=device)
-        else:
-            self.min_p_device = self.min_p_cpu_tensor
-        self.min_p = self.min_p_device[:0]
-
-
-MinPLogitsProcessor.__init__ = min_p_logits_processor_init_func
diff --git a/vllm_ascend/sample/logits_processor/__init__.py b/vllm_ascend/sample/logits_processor/__init__.py
new file mode 100644
index 00000000000..5f810bfcd12
--- /dev/null
+++ b/vllm_ascend/sample/logits_processor/__init__.py
@@ -0,0 +1,50 @@
+import itertools
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Union
+
+import torch
+from vllm.logger import init_logger
+from vllm.v1.sample import logits_processor
+from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
+                                                     MinTokensLogitsProcessor)
+from vllm.v1.sample.logits_processor.interface import LogitsProcessor
+from vllm.v1.sample.logits_processor.state import LogitsProcessors
+
+from vllm_ascend.sample.logits_processor.builtin import \
+    AscendMinPLogitsProcessor
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+# Error message when the user tries to initialize vLLM with a pooling model
+# and custom logitsprocs
+STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
+                                   " logits processors.")
+
+BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
+    MinTokensLogitsProcessor,
+    LogitBiasLogitsProcessor,
+    AscendMinPLogitsProcessor,
+]
+
+
+def build_logitsprocs(
+    vllm_config: "VllmConfig",
+    device: torch.device,
+    is_pin_memory: bool,
+    is_pooling_model: bool,
+    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
+) -> LogitsProcessors:
+    if is_pooling_model:
+        if custom_logitsprocs:
+            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
+        logger.debug("Skipping logits processor loading because pooling models"
+                     " do not support logits processors.")
+        return LogitsProcessors()
+    custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
+        custom_logitsprocs)
+    return LogitsProcessors(
+        ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
+            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
diff --git a/vllm_ascend/sample/logits_processor/builtin.py b/vllm_ascend/sample/logits_processor/builtin.py
new file mode 100644
index 00000000000..f38d940240c
--- /dev/null
+++ b/vllm_ascend/sample/logits_processor/builtin.py
@@ -0,0 +1,35 @@
+import torch
+from vllm.config import VllmConfig
+from vllm.v1.sample.logits_processor import MinPLogitsProcessor
+
+
+class AscendMinPLogitsProcessor(MinPLogitsProcessor):
+
+    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
+                 is_pin_memory: bool):
+        super().__init__(vllm_config, device, is_pin_memory)
+
+        decode_max_num_seqs = getattr(vllm_config.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        if decode_max_num_seqs != 0:
+            max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
+                               decode_max_num_seqs)
+
+            self.min_p_count: int = 0
+
+            self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
+                                                dtype=torch.float32,
+                                                device="cpu",
+                                                pin_memory=is_pin_memory)
+            self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+
+            self.use_double_tensor = torch.device(device).type != "cpu"
+
+            if self.use_double_tensor:
+                # Pre-allocated device tensor
+                self.min_p_device: torch.Tensor = torch.empty(
+                    (max_num_reqs, ), dtype=torch.float32, device=device)
+            else:
+                self.min_p_device = self.min_p_cpu_tensor
+            # Current slice of the device tensor
+            self.min_p: torch.Tensor = self.min_p_device[:0]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a19ee6610a8..f33503a3feb 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -66,7 +66,6 @@
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
-from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -86,6 +85,7 @@
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
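
---

Usage note: the sketch below shows how the two scheduler options this series adds are meant to be driven from offline inference. It is a minimal illustration under assumptions, not part of the series: it assumes vllm-ascend routes additional_config["ascend_scheduler_config"] into AscendSchedulerConfig.initialize_from_config as the docs patch above describes, and the model name, prompts, and sampling values are placeholders.

    from vllm import LLM, SamplingParams

    # With enable_pd_transfer=True the AscendScheduler starts in the "prefill"
    # phase, parks each request in finished_prefill_reqs once its prompt has
    # been processed, and switches to a pure "decode" phase only after both
    # waiting and running are empty. decode_max_num_seqs then raises the decode
    # batch cap to max(max_num_seqs, decode_max_num_seqs).
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        additional_config={
            "ascend_scheduler_config": {
                "enabled": True,
                "enable_pd_transfer": True,
                "decode_max_num_seqs": 48,
            },
        },
    )
    outputs = llm.generate(["Hello, world!"] * 32,
                           SamplingParams(temperature=0.8, max_tokens=64))

Because decode only begins after every prompt has prefilled, this trades time to first token for larger, more uniform decode batches, which is why the docs restrict the option to offline inference.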