diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 2f5bd536b..95c60f249 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -357,44 +357,46 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
             logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
             self.execute_model(scheduler_output)
 
-        # one decode iteration across all sequences
-        req_ids = []
-        new_token_ids = []
-        new_block_ids = []
-        num_computed_tokens = []
-        for req in dummy_requests:
-            req_ids.append(req.req_id)
-            new_token_ids.append([
-                valid_token_ids_tensor[torch.randint(
-                    0, len(valid_token_ids_tensor), (1, )).item()]
-            ])  # placeholder token
-            new_block_ids.append([req.block_ids])
-            num_computed_tokens.append(prompt_len)
-        cached_request_data = CachedRequestData(
-            req_ids=req_ids,
-            resumed_from_preemption=False,
-            new_token_ids=new_token_ids,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
-        )
+            # one decode iteration across all sequences
+            req_ids = []
+            new_token_ids = []
+            new_block_ids = []
+            num_computed_tokens = []
+            for req in dummy_requests:
+                req_ids.append(req.req_id)
+                new_token_ids.append([
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ])  # placeholder token
+                new_block_ids.append([req.block_ids])
+                num_computed_tokens.append(prompt_len)
+            cached_request_data = CachedRequestData(
+                req_ids=req_ids,
+                resumed_from_preemption=False,
+                new_token_ids=new_token_ids,
+                new_block_ids=new_block_ids,
+                num_computed_tokens=num_computed_tokens,
+            )
 
-        scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=[],
-            scheduled_cached_reqs=cached_request_data,
-            num_scheduled_tokens={f"warmup-{i}": 1
-                                  for i in range(batch_size)},
-            total_num_scheduled_tokens=batch_size,
-            scheduled_spec_decode_tokens={},
-            scheduled_encoder_inputs={},
-            num_common_prefix_blocks=0,
-            finished_req_ids=set(),
-            free_encoder_input_ids=[],
-            structured_output_request_ids={},
-            grammar_bitmask=None,
-        )
-        logger.info("[WARMUP] Decode...")
-        self.execute_model(scheduler_output)
-        self._cleanup_model_runner(request=dummy_requests)
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=cached_request_data,
+                num_scheduled_tokens={
+                    f"warmup-{i}": 1
+                    for i in range(batch_size)
+                },
+                total_num_scheduled_tokens=batch_size,
+                scheduled_spec_decode_tokens={},
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+            )
+            logger.info("[WARMUP] Decode...")
+            self.execute_model(scheduler_output)
+            self._cleanup_model_runner(request=dummy_requests)
 
         # warmup_mode completes the graph compilation, but we need to do
         # one additional prefill to deploy the compiled program to the device,