From a76c77ed06279a9962d42f1efc3f752051456aea Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Mon, 1 Jun 2026 15:26:42 +0000 Subject: [PATCH 1/3] Fix NixlConnector PD + Spec Decode acceptance (2 GPUs) issue Signed-off-by: yewentao256 --- vllm/v1/core/sched/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 73d3dcb4b65e..c9b261abdf47 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -707,9 +707,10 @@ def schedule(self) -> SchedulerOutput: # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - limit_lookahead_tokens = load_kv_async and self.use_eagle effective_lookahead_tokens = ( - 0 if limit_lookahead_tokens else self.num_lookahead_tokens + 0 + if self.use_eagle and request.num_computed_tokens == 0 + else self.num_lookahead_tokens ) # Determine if we need to allocate cross-attention blocks. From 39494b00daeb61b916aaf96330aee62041c263b1 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Mon, 1 Jun 2026 16:48:26 +0000 Subject: [PATCH 2/3] fix dflash Signed-off-by: yewentao256 --- vllm/v1/core/sched/scheduler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c9b261abdf47..f4e6a9ee68ec 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -210,6 +210,7 @@ def __init__( speculative_config = vllm_config.speculative_config self.use_eagle = False + self.use_dflash = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens @@ -219,6 +220,7 @@ def __init__( if speculative_config.uses_draft_model(): self.num_lookahead_tokens = self.num_spec_tokens if speculative_config.use_dflash(): + self.use_dflash = True # DFlash requires an extra lookahead slot since it uses in-fill-style # decoding instead of standard next-token sampling, so it has a query # for the last sampled token plus queries for each draft token. @@ -709,7 +711,11 @@ def schedule(self) -> SchedulerOutput: # of local and remote blocks. effective_lookahead_tokens = ( 0 - if self.use_eagle and request.num_computed_tokens == 0 + if ( + self.use_eagle + and not self.use_dflash + and request.num_computed_tokens == 0 + ) else self.num_lookahead_tokens ) From 61b3b6bbed3fba95a07d91562ca333666cfa40e6 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 3 Jun 2026 18:14:14 +0000 Subject: [PATCH 3/3] update to prefill node only Signed-off-by: yewentao256 --- vllm/v1/core/sched/scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1f67e5ee9342..40d7fc8b3368 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -733,12 +733,17 @@ def schedule(self) -> SchedulerOutput: # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. + is_pd_prefill_producer = ( + request.num_computed_tokens == 0 + and request.kv_transfer_params is not None + and request.kv_transfer_params.get("do_remote_decode", False) + ) effective_lookahead_tokens = ( 0 if ( self.use_eagle and not self.use_dflash - and request.num_computed_tokens == 0 + and (load_kv_async or is_pd_prefill_producer) ) else self.num_lookahead_tokens )