From f7f6bb6994e4fb84ce3709d9828486e94d756d46 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 1 Jul 2025 10:34:29 -0700 Subject: [PATCH 01/10] =?UTF-8?q?=F0=9F=90=9B=20req=5Fids=20is=20now=20a?= =?UTF-8?q?=20list?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 48bfe0d58..c1493a5ab 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -516,7 +516,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, # Set up dummy cached_requests for decode steps cached_requests = [ CachedRequestData( - req_id=req.req_id, + req_ids=[req.req_id], resumed_from_preemption=False, new_token_ids=[ valid_token_ids_tensor[torch.randint( From 8f7acaf8a664520aa1e5208e3914cf835b540cc5 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 1 Jul 2025 10:40:15 -0700 Subject: [PATCH 02/10] =?UTF-8?q?=F0=9F=90=9B=20req=5Fids=20is=20now=20a?= =?UTF-8?q?=20list?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index c1493a5ab..41f25aafd 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -361,7 +361,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): # one decode iteration across both sequences cached_requests = [ CachedRequestData( - req_id=req.req_id, + req_ids=[req.req_id], resumed_from_preemption=False, new_token_ids=[ valid_token_ids_tensor[torch.randint( From 06b8d7d85b72205c009cd8f347a702e052257357 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 1 Jul 2025 12:35:07 -0700 Subject: [PATCH 03/10] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20using=20cached=20req?= =?UTF-8?q?uests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_model_runner.py | 28 +++++++------ vllm_spyre/v1/worker/spyre_worker.py | 48 ++++++++++++++-------- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 17abe4129..fd593b0bc 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -241,32 +241,35 @@ def update_states(self, scheduler_output: SchedulerOutput): # # NOTE: req_state.output_token_ids is being mutated. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] + # for req_data in scheduler_output.scheduled_cached_reqs: + # req_id = req_data.req_ids[0] + # req_state = self.requests[req_id] + # Update the cached states. - num_computed_tokens = req_data.num_computed_tokens + num_computed_tokens = req_data.num_computed_tokens[i] + new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec decode tokens. 
- num_new_tokens = (num_computed_tokens + - len(req_data.new_token_ids) - + num_new_tokens = (num_computed_tokens + len(new_token_ids) - req_state.num_tokens) if num_new_tokens == 1: # Avoid slicing list in most common case. - req_state.output_token_ids.append(req_data.new_token_ids[-1]) + req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: req_state.output_token_ids.extend( - req_data.new_token_ids[-num_new_tokens:]) + new_token_ids[-num_new_tokens:]) req_index = self.input_batch.get_req_index(req_id) # Add new_token_ids to token_ids_cpu. # TODO: Update for spec decoding in the future start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len(req_data.new_token_ids) + end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = req_data.new_token_ids + req_index, start_token_index:end_token_index] = new_token_ids if scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids: @@ -277,8 +280,7 @@ def update_states(self, scheduler_output: SchedulerOutput): def _prepare_prompt(self, _: list[NewRequestData]) -> ModelForwardInputs: raise NotImplementedError - def _prepare_decode(self, - _: list[CachedRequestData]) -> ModelForwardInputs: + def _prepare_decode(self, _: CachedRequestData) -> ModelForwardInputs: raise NotImplementedError def prepare_model_input( @@ -291,7 +293,7 @@ def prepare_model_input( # Prepare input tensors. if is_prompt: # Assert no running requests - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs.req_ids) == 0 return self._prepare_prompt(scheduler_output.scheduled_new_reqs) else: diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 41f25aafd..2ef0698b1 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -344,7 +344,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): for i, req in enumerate(dummy_requests): scheduler_output = SchedulerOutput( scheduled_new_reqs=[req], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req.req_id: prompt_len}, total_num_scheduled_tokens=prompt_len, scheduled_spec_decode_tokens={}, @@ -359,22 +359,38 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): self.execute_model(scheduler_output) # one decode iteration across both sequences - cached_requests = [ - CachedRequestData( - req_ids=[req.req_id], - resumed_from_preemption=False, - new_token_ids=[ - valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (1, )).item()] - ], # placeholder token - new_block_ids=req.block_ids, - num_computed_tokens=prompt_len, - ) for req in dummy_requests - ] + # cached_requests = [ + # CachedRequestData( + # req_ids=[req.req_id], + # resumed_from_preemption=False, + # new_token_ids=[ + # valid_token_ids_tensor[torch.randint( + # 0, len(valid_token_ids_tensor), (1, )).item()] + # ], # placeholder token + # new_block_ids=req.block_ids, + # num_computed_tokens=prompt_len, + # ) for req in dummy_requests + # ] + req_ids = [] + new_token_ids = [] + new_block_ids = [] + for req in dummy_requests: + req_ids.append(req.req_id) + new_token_ids.append(valid_token_ids_tensor[torch.randint( + 0, len(valid_token_ids_tensor), + (1, )).item()]), # placeholder token + new_block_ids.append(req.block_ids), + cached_request_data = CachedRequestData( + req_ids=req_ids, + 
resumed_from_preemption=False, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=[prompt_len], + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=cached_requests, + scheduled_cached_reqs=cached_request_data, num_scheduled_tokens={f"warmup-{i}": 1 for i in range(batch_size)}, total_num_scheduled_tokens=batch_size, @@ -393,7 +409,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): # Needed to clean up the data of model runner scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, # NOTE: this means no work to do total_num_scheduled_tokens=0, @@ -530,7 +546,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, # Set up scheduler_output for execute_model scheduler_output = SchedulerOutput( scheduled_new_reqs=dummy_requests, - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={i: prompt_len for i in range(batch_size)}, total_num_scheduled_tokens=sum(prompt_len From a064d3b9284664d2076135bcd16ca0f5bb9558c1 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 2 Jul 2025 11:55:03 -0700 Subject: [PATCH 04/10] =?UTF-8?q?=F0=9F=90=9B=20first=20pass=20for=20sb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_model_runner.py | 24 ++++---- vllm_spyre/v1/worker/spyre_worker.py | 68 ++++++++++++++-------- 2 files changed, 59 insertions(+), 33 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index fd593b0bc..063d4133d 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -457,19 +457,22 @@ def _prepare_prompt( def _prepare_decode( self, - cached_requests: list[CachedRequestData], + cached_requests: CachedRequestData, ) -> ModelForwardInputs: - assert len(cached_requests) > 0 + assert len(cached_requests.req_ids) > 0 input_tokens: list[list[int]] = [ [0] for _ in range(self._position_ids.shape[0]) ] - for cached_request in cached_requests: + for i, req_id in enumerate(cached_requests.req_ids): + # for cached_request in cached_requests: # TODO: Will this always just be one token ID if there's no spec # or jump decoding? 
- generation_token = cached_request.new_token_ids[-1] - input_tokens[self.input_batch.req_id_to_index[ - cached_request.req_id]] = [generation_token] + new_token_ids = cached_requests.new_token_ids[i] + generation_token = new_token_ids[-1] + input_tokens[self.input_batch.req_id_to_index[req_id]] = [ + generation_token + ] # update position ids and attention mask self._update_position_ids() @@ -754,20 +757,21 @@ def _prepare_prompt( def _prepare_decode( self, - cached_requests: list[CachedRequestData], + cached_requests: CachedRequestData, ) -> ModelForwardInputs: - assert len(cached_requests) > 0 + assert len(cached_requests.req_ids) > 0 input_tokens = [] input_positions = [] block_table = [] slot_mapping = [] left_padded_prompt_mask = [] - self.model.indices = torch.ones(len(cached_requests), + self.model.indices = torch.ones(len(cached_requests.req_ids), dtype=torch.bool, device="cpu") - assert len(self.input_batch.req_id_to_index) == len(cached_requests) + assert len(self.input_batch.req_id_to_index) == len( + cached_requests.req_ids) # TODO(wallas): I think we can do better here, without sorting or # creating an intermediary dictionary cached_reqs_map = {c.req_id: c for c in cached_requests} diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 2ef0698b1..e0d210a37 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -372,14 +372,15 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): # ) for req in dummy_requests # ] req_ids = [] - new_token_ids = [] - new_block_ids = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[tuple[list[int], ...]] = [] for req in dummy_requests: req_ids.append(req.req_id) - new_token_ids.append(valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), - (1, )).item()]), # placeholder token - new_block_ids.append(req.block_ids), + new_token_ids.append([ + valid_token_ids_tensor[torch.randint( + 0, len(valid_token_ids_tensor), (1, )).item()] + ]), # placeholder token + new_block_ids.append([req.block_ids]), cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, @@ -530,23 +531,43 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, ] # Set up dummy cached_requests for decode steps - cached_requests = [ - CachedRequestData( - req_ids=[req.req_id], - resumed_from_preemption=False, - new_token_ids=[ - valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (1, )).item()] - ], # placeholder token - new_block_ids=req.block_ids, - num_computed_tokens=req.num_computed_tokens, - ) for req in dummy_requests - ] + # cached_requests = [ + # CachedRequestData( + # req_ids=[req.req_id], + # resumed_from_preemption=False, + # new_token_ids=[ + # valid_token_ids_tensor[torch.randint( + # 0, len(valid_token_ids_tensor), (1, )).item()] + # ], # placeholder token + # new_block_ids=req.block_ids, + # num_computed_tokens=req.num_computed_tokens, + # ) for req in dummy_requests + # ] + req_ids = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[tuple[list[int], ...]] = [] + num_computed_tokens = [] + for req in dummy_requests: + req_ids.append(req.req_id) + new_token_ids.append([ + valid_token_ids_tensor[torch.randint( + 0, len(valid_token_ids_tensor), (1, )).item()] + ]), # placeholder token + new_block_ids.append([req.block_ids]), + num_computed_tokens.append(req.num_computed_tokens) + + cached_request_data = CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=False, + 
new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) # Set up scheduler_output for execute_model scheduler_output = SchedulerOutput( scheduled_new_reqs=dummy_requests, - scheduled_cached_reqs=CachedRequestData.make_empty(), + scheduled_cached_reqs=cached_request_data, num_scheduled_tokens={i: prompt_len for i in range(batch_size)}, total_num_scheduled_tokens=sum(prompt_len @@ -565,7 +586,8 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, # The fixed size warmup needs to happen only in here with _maybe_warmup_context(): self._warmup_model_forward_pass(scheduler_output, dummy_requests, - cached_requests, num_decode_tokens) + cached_request_data, + num_decode_tokens) self.perf_metrics.log("warmup 1 time", time.time() - warmup_start_t, batch_size=batch_size, @@ -576,7 +598,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, logger.info("Warmup forward pass 2/2...") warmup2_start_t = time.time() self._warmup_model_forward_pass(scheduler_output, dummy_requests, - cached_requests, num_decode_tokens) + cached_request_data, num_decode_tokens) warmup_end_t = time.time() warmup_total_t = warmup_end_t - warmup_start_t @@ -595,12 +617,12 @@ def _warmup_model_forward_pass( self, scheduler_output: SchedulerOutput, requests: list[NewRequestData], - cached_requests: list[CachedRequestData], + cached_requests: CachedRequestData, num_decode_tokens, ): """Handle a complete forward pass""" scheduler_output.scheduled_new_reqs = requests - scheduler_output.scheduled_cached_reqs = [] + scheduler_output.scheduled_cached_reqs = CachedRequestData.make_empty() self.execute_model(scheduler_output) # Prefill # Switch to cached requests to trigger decoding steps From 989f6d20deca46053ccbbc15389da2502c10c6c7 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 2 Jul 2025 14:16:11 -0700 Subject: [PATCH 05/10] =?UTF-8?q?=F0=9F=90=9B=20first=20pass=20for=20cb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_model_runner.py | 40 ++++++++++++++-------- vllm_spyre/v1/worker/spyre_worker.py | 4 ++- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 063d4133d..53e8922cc 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -757,49 +757,59 @@ def _prepare_prompt( def _prepare_decode( self, - cached_requests: CachedRequestData, + cached_request_data: CachedRequestData, ) -> ModelForwardInputs: - assert len(cached_requests.req_ids) > 0 + assert len(cached_request_data.req_ids) > 0 input_tokens = [] input_positions = [] block_table = [] slot_mapping = [] left_padded_prompt_mask = [] - self.model.indices = torch.ones(len(cached_requests.req_ids), + self.model.indices = torch.ones(len(cached_request_data.req_ids), dtype=torch.bool, device="cpu") assert len(self.input_batch.req_id_to_index) == len( - cached_requests.req_ids) + cached_request_data.req_ids) # TODO(wallas): I think we can do better here, without sorting or # creating an intermediary dictionary - cached_reqs_map = {c.req_id: c for c in cached_requests} + # for req in cached_request_data: + # cached_reqs_map = {c.req_id: c for c in cached_requests} + + cached_reqs_map = { + req_id: i + for i, req_id in enumerate(cached_request_data.req_ids) + } + req_ids = self.input_batch.sorted_requests_ids + # for _, req_id in 
enumerate(cached_request_data.req_ids): for req_id in req_ids: + # TODO: Will this always just be one token ID if there's no spec # or jump decoding? - cached_request = cached_reqs_map[req_id] + # cached_request = cached_reqs_map[req_id] # adding new blocks if needed if self.tkv // self.block_size + 1 > len( - self.req_ids2blocks[cached_request.req_id]): - self.req_ids2blocks[cached_request.req_id].append( - self.free_blocks.popleft()) - block_table.append(self.req_ids2blocks[cached_request.req_id]) + self.req_ids2blocks[req_id]): + self.req_ids2blocks[req_id].append(self.free_blocks.popleft()) + block_table.append(self.req_ids2blocks[req_id]) # slot_mapping for all blocks of sequence start_slot = block_table[-1][-1] * self.block_size offset = self.tkv % self.block_size slot = [start_slot + offset] slot_mapping.append(slot) - - generation_token = cached_request.new_token_ids[-1] + new_token_ids = cached_request_data.new_token_ids[ + cached_reqs_map[req_id]] + generation_token = new_token_ids[-1] input_tokens.append([generation_token]) - seq_len = cached_request.num_computed_tokens + seq_len = cached_request_data.num_computed_tokens[ + cached_reqs_map[req_id]] input_positions.append([seq_len]) - req_state = self.requests[cached_request.req_id] + req_state = self.requests[req_id] left_padded_prompt_mask.append(req_state.left_padding) input_tokens = torch.tensor(input_tokens, @@ -825,7 +835,7 @@ def _prepare_decode( dtype=torch.int64) # add pads for min decode batch size of 2 (Spyre compiler constraint) - if len(cached_requests) == 1: + if len(cached_request_data.req_ids) == 1: padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu") self.model.indices = torch.cat( (self.model.indices, padd_seq_indices), -1) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index e0d210a37..daf8d06e6 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -374,6 +374,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): req_ids = [] new_token_ids: list[list[int]] = [] new_block_ids: list[tuple[list[int], ...]] = [] + num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) new_token_ids.append([ @@ -381,12 +382,13 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): 0, len(valid_token_ids_tensor), (1, )).item()] ]), # placeholder token new_block_ids.append([req.block_ids]), + num_computed_tokens.append(prompt_len), cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, new_token_ids=new_token_ids, new_block_ids=new_block_ids, - num_computed_tokens=[prompt_len], + num_computed_tokens=num_computed_tokens, ) scheduler_output = SchedulerOutput( From bd7e008e94140ebaf6ebe0c0bcfaf710425c22af Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 2 Jul 2025 14:53:46 -0700 Subject: [PATCH 06/10] =?UTF-8?q?=F0=9F=90=9B=20fix=20merge=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 1cb6e37c8..136135a92 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -395,10 +395,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=cached_request_data, - 
num_scheduled_tokens={ - f"warmup-{i}": 1 - for i in range(batch_size) - }, + num_scheduled_tokens={f"warmup-{i}": 1 + for i in range(batch_size)}, total_num_scheduled_tokens=batch_size, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, @@ -449,7 +447,7 @@ def _cleanup_model_runner(self, request) -> None: # Needed to clean up the data of model runner scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, # NOTE: this means no work to do total_num_scheduled_tokens=0, From fe8e64c8a679a3f0fa79130ecf805347fec142f2 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 3 Jul 2025 09:30:06 -0700 Subject: [PATCH 07/10] =?UTF-8?q?=F0=9F=8E=A8=20renaming=20vars?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_model_runner.py | 10 +++++----- vllm_spyre/v1/worker/spyre_worker.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 53e8922cc..1776f9068 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -457,18 +457,18 @@ def _prepare_prompt( def _prepare_decode( self, - cached_requests: CachedRequestData, + cached_request_data: CachedRequestData, ) -> ModelForwardInputs: - assert len(cached_requests.req_ids) > 0 + assert len(cached_request_data.req_ids) > 0 input_tokens: list[list[int]] = [ [0] for _ in range(self._position_ids.shape[0]) ] - for i, req_id in enumerate(cached_requests.req_ids): - # for cached_request in cached_requests: + for i, req_id in enumerate(cached_request_data.req_ids): + # for cached_request in cached_request_data: # TODO: Will this always just be one token ID if there's no spec # or jump decoding? 
- new_token_ids = cached_requests.new_token_ids[i] + new_token_ids = cached_request_data.new_token_ids[i] generation_token = new_token_ids[-1] input_tokens[self.input_batch.req_id_to_index[req_id]] = [ generation_token diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 136135a92..94371e130 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -642,7 +642,7 @@ def _warmup_model_forward_pass( self, scheduler_output: SchedulerOutput, requests: list[NewRequestData], - cached_requests: CachedRequestData, + cached_request_data: CachedRequestData, num_decode_tokens, ): """Handle a complete forward pass""" @@ -652,7 +652,7 @@ def _warmup_model_forward_pass( # Switch to cached requests to trigger decoding steps scheduler_output.scheduled_new_reqs = [] - scheduler_output.scheduled_cached_reqs = cached_requests + scheduler_output.scheduled_cached_reqs = cached_request_data for _ in range(num_decode_tokens - 1): self.execute_model(scheduler_output) From df9214b3b257d5872cd5609cdad5d20ccd191b51 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 3 Jul 2025 09:36:39 -0700 Subject: [PATCH 08/10] =?UTF-8?q?=F0=9F=94=A5=20remove=20commented=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_model_runner.py | 12 -------- vllm_spyre/v1/worker/spyre_worker.py | 32 +++------------------- 2 files changed, 4 insertions(+), 40 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 1776f9068..d3847f00f 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -245,10 +245,6 @@ def update_states(self, scheduler_output: SchedulerOutput): for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] - # for req_data in scheduler_output.scheduled_cached_reqs: - # req_id = req_data.req_ids[0] - # req_state = self.requests[req_id] - # Update the cached states. num_computed_tokens = req_data.num_computed_tokens[i] new_token_ids = req_data.new_token_ids[i] @@ -465,7 +461,6 @@ def _prepare_decode( ] for i, req_id in enumerate(cached_request_data.req_ids): - # for cached_request in cached_request_data: # TODO: Will this always just be one token ID if there's no spec # or jump decoding? new_token_ids = cached_request_data.new_token_ids[i] @@ -774,22 +769,15 @@ def _prepare_decode( cached_request_data.req_ids) # TODO(wallas): I think we can do better here, without sorting or # creating an intermediary dictionary - # for req in cached_request_data: - # cached_reqs_map = {c.req_id: c for c in cached_requests} - cached_reqs_map = { req_id: i for i, req_id in enumerate(cached_request_data.req_ids) } - req_ids = self.input_batch.sorted_requests_ids - # for _, req_id in enumerate(cached_request_data.req_ids): for req_id in req_ids: - # TODO: Will this always just be one token ID if there's no spec # or jump decoding? 
- # cached_request = cached_reqs_map[req_id] # adding new blocks if needed if self.tkv // self.block_size + 1 > len( diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 94371e130..846dfb315 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -360,21 +360,9 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): self.execute_model(scheduler_output) # one decode iteration across both sequences - # cached_requests = [ - # CachedRequestData( - # req_ids=[req.req_id], - # resumed_from_preemption=False, - # new_token_ids=[ - # valid_token_ids_tensor[torch.randint( - # 0, len(valid_token_ids_tensor), (1, )).item()] - # ], # placeholder token - # new_block_ids=req.block_ids, - # num_computed_tokens=prompt_len, - # ) for req in dummy_requests - # ] req_ids = [] - new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] + new_token_ids = [] + new_block_ids = [] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) @@ -556,21 +544,9 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, ] # Set up dummy cached_requests for decode steps - # cached_requests = [ - # CachedRequestData( - # req_ids=[req.req_id], - # resumed_from_preemption=False, - # new_token_ids=[ - # valid_token_ids_tensor[torch.randint( - # 0, len(valid_token_ids_tensor), (1, )).item()] - # ], # placeholder token - # new_block_ids=req.block_ids, - # num_computed_tokens=req.num_computed_tokens, - # ) for req in dummy_requests - # ] req_ids = [] - new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] + new_token_ids = [] + new_block_ids = [] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) From f06578535519b759e1d580e6698efc9112b6d14b Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 3 Jul 2025 10:16:43 -0700 Subject: [PATCH 09/10] =?UTF-8?q?=F0=9F=94=A5=20remove=20extra=20commas?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 846dfb315..a84596abf 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -369,9 +369,9 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): new_token_ids.append([ valid_token_ids_tensor[torch.randint( 0, len(valid_token_ids_tensor), (1, )).item()] - ]), # placeholder token - new_block_ids.append([req.block_ids]), - num_computed_tokens.append(prompt_len), + ]) # placeholder token + new_block_ids.append([req.block_ids]) + num_computed_tokens.append(prompt_len) cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, @@ -553,8 +553,8 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, new_token_ids.append([ valid_token_ids_tensor[torch.randint( 0, len(valid_token_ids_tensor), (1, )).item()] - ]), # placeholder token - new_block_ids.append([req.block_ids]), + ]) # placeholder token + new_block_ids.append([req.block_ids]) num_computed_tokens.append(req.num_computed_tokens) cached_request_data = CachedRequestData( From d96533859c1aea201524860e91f05b0e9f670082 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 3 Jul 2025 11:33:07 -0700 Subject: [PATCH 10/10] =?UTF-8?q?=F0=9F=9A=A7=20wip=20to=20see=20if=20test?= 
=?UTF-8?q?s=20pass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 760c79d21..1b97a6587 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: - name: "default" repo: "" - name: "vLLM:main" - repo: "git+https://github.com/vllm-project/vllm --branch main" + repo: "git+https://github.com/vllm-project/vllm@02cabff207ca68094a73ba21296c82cdbcb1d1a5" test_suite: - name: "static batching" markers: "cpu and decoder and not cb"
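
Note on the API change this series adapts to: upstream vLLM replaced the per-request
CachedRequestData objects in SchedulerOutput.scheduled_cached_reqs with a single batched
object holding parallel lists (req_ids, new_token_ids, new_block_ids, num_computed_tokens)
plus a make_empty() constructor, which is why the patches above switch from iterating
request objects to indexing lists. The following is a minimal sketch, assuming a simplified
stand-in for the upstream dataclass (field names are taken from the diffs; types, defaults,
and the dummy request values are assumptions for illustration, not the real definition):

    from dataclasses import dataclass, field

    @dataclass
    class CachedRequestData:
        # One entry per running request, kept as parallel lists.
        req_ids: list[str] = field(default_factory=list)
        resumed_from_preemption: bool = False
        new_token_ids: list[list[int]] = field(default_factory=list)
        new_block_ids: list[list[list[int]]] = field(default_factory=list)
        num_computed_tokens: list[int] = field(default_factory=list)

        @classmethod
        def make_empty(cls) -> "CachedRequestData":
            # Used wherever a scheduler step carries no cached requests.
            return cls()

    # Building one batched object for a warmup decode step, mirroring the
    # loops added in spyre_worker.py (dummy values stand in for real requests).
    dummy = [("warmup-0", 5, [[0, 1]]), ("warmup-1", 5, [[2, 3]])]
    req_ids, new_token_ids, new_block_ids, num_computed_tokens = [], [], [], []
    for req_id, prompt_len, block_ids in dummy:
        req_ids.append(req_id)
        new_token_ids.append([0])          # placeholder token
        new_block_ids.append(block_ids)
        num_computed_tokens.append(prompt_len)

    cached_request_data = CachedRequestData(
        req_ids=req_ids,
        resumed_from_preemption=False,
        new_token_ids=new_token_ids,
        new_block_ids=new_block_ids,
        num_computed_tokens=num_computed_tokens,
    )

    # Consumers now index the parallel lists instead of iterating objects,
    # as done in update_states() and _prepare_decode() above.
    for i, req_id in enumerate(cached_request_data.req_ids):
        last_token = cached_request_data.new_token_ids[i][-1]
        print(req_id, last_token, cached_request_data.num_computed_tokens[i])

In this shape, an "empty" scheduler step passes CachedRequestData.make_empty() rather than
an empty list, and per-request lookups go through an index map (req_id -> position in
req_ids), which is the pattern the continuous-batching path builds in _prepare_decode().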