diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index fbf6a3f17..313cce320 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -363,8 +363,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): with _maybe_warmup_context(): self._dynamic_warmup(dummy_requests=dummy_requests, prompt_len=prompt_len, - batch_size=batch_size, - valid_token_ids_tensor=valid_token_ids_tensor) + batch_size=batch_size) # warmup_mode completes the graph compilation, but we need to do # one additional prefill to deploy the compiled program to the device, @@ -466,22 +465,17 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, # Set up dummy cached_requests for decode steps req_ids = [] - new_token_ids = [] new_block_ids = [] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) - new_token_ids.append([ - valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (1, )).item()] - ]) # placeholder token new_block_ids.append([req.block_ids]) num_computed_tokens.append(req.num_computed_tokens) cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, - new_token_ids=new_token_ids, + new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) @@ -542,7 +536,6 @@ def _dynamic_warmup( dummy_requests: list[NewRequestData], prompt_len: int, batch_size: int, - valid_token_ids_tensor: torch.Tensor, ) -> None: assert ( @@ -569,21 +562,16 @@ def _dynamic_warmup( # one decode iteration across all sequences req_ids = [] - new_token_ids = [] new_block_ids = [] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) - new_token_ids.append([ - valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (1, )).item()] - ]) # placeholder token new_block_ids.append([req.block_ids]) num_computed_tokens.append(prompt_len) cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, - new_token_ids=new_token_ids, + new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, )