diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 760c79d21..1b97a6587 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -36,7 +36,7 @@ jobs:
           - name: "default"
             repo: ""
           - name: "vLLM:main"
-            repo: "git+https://github.com/vllm-project/vllm --branch main"
+            repo: "git+https://github.com/vllm-project/vllm@02cabff207ca68094a73ba21296c82cdbcb1d1a5"
         test_suite:
           - name: "static batching"
             markers: "cpu and decoder and not cb"
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 17abe4129..d3847f00f 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -241,32 +241,31 @@ def update_states(self, scheduler_output: SchedulerOutput):
         #
         # NOTE: req_state.output_token_ids is being mutated.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
+        req_data = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
 
             # Update the cached states.
-            num_computed_tokens = req_data.num_computed_tokens
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_token_ids = req_data.new_token_ids[i]
 
             # Add the sampled token(s) from the previous step (if any).
             # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens +
-                              len(req_data.new_token_ids) -
+            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                               req_state.num_tokens)
             if num_new_tokens == 1:
                 # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(req_data.new_token_ids[-1])
+                req_state.output_token_ids.append(new_token_ids[-1])
             elif num_new_tokens > 0:
                 req_state.output_token_ids.extend(
-                    req_data.new_token_ids[-num_new_tokens:])
+                    new_token_ids[-num_new_tokens:])
 
             req_index = self.input_batch.get_req_index(req_id)
             # Add new_token_ids to token_ids_cpu.
             # TODO: Update for spec decoding in the future
             start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+            end_token_index = num_computed_tokens + len(new_token_ids)
             self.input_batch.token_ids_cpu[
-                req_index,
-                start_token_index:end_token_index] = req_data.new_token_ids
+                req_index, start_token_index:end_token_index] = new_token_ids
 
         if scheduler_output.finished_req_ids:
             for req_id in scheduler_output.finished_req_ids:
@@ -277,8 +276,7 @@ def update_states(self, scheduler_output: SchedulerOutput):
     def _prepare_prompt(self, _: list[NewRequestData]) -> ModelForwardInputs:
         raise NotImplementedError
 
-    def _prepare_decode(self,
-                        _: list[CachedRequestData]) -> ModelForwardInputs:
+    def _prepare_decode(self, _: CachedRequestData) -> ModelForwardInputs:
         raise NotImplementedError
 
     def prepare_model_input(
@@ -291,7 +289,7 @@ def prepare_model_input(
 
         # Prepare input tensors.
         if is_prompt:
             # Assert no running requests
-            assert len(scheduler_output.scheduled_cached_reqs) == 0
+            assert len(scheduler_output.scheduled_cached_reqs.req_ids) == 0
             return self._prepare_prompt(scheduler_output.scheduled_new_reqs)
         else:
@@ -455,19 +453,21 @@ def _prepare_prompt(
 
     def _prepare_decode(
         self,
-        cached_requests: list[CachedRequestData],
+        cached_request_data: CachedRequestData,
     ) -> ModelForwardInputs:
-        assert len(cached_requests) > 0
+        assert len(cached_request_data.req_ids) > 0
 
         input_tokens: list[list[int]] = [
             [0] for _ in range(self._position_ids.shape[0])
         ]
-        for cached_request in cached_requests:
+        for i, req_id in enumerate(cached_request_data.req_ids):
             # TODO: Will this always just be one token ID if there's no spec
             # or jump decoding?
-            generation_token = cached_request.new_token_ids[-1]
-            input_tokens[self.input_batch.req_id_to_index[
-                cached_request.req_id]] = [generation_token]
+            new_token_ids = cached_request_data.new_token_ids[i]
+            generation_token = new_token_ids[-1]
+            input_tokens[self.input_batch.req_id_to_index[req_id]] = [
+                generation_token
+            ]
 
         # update position ids and attention mask
         self._update_position_ids()
@@ -752,48 +752,52 @@ def _prepare_prompt(
 
     def _prepare_decode(
        self,
-        cached_requests: list[CachedRequestData],
+        cached_request_data: CachedRequestData,
     ) -> ModelForwardInputs:
-        assert len(cached_requests) > 0
+        assert len(cached_request_data.req_ids) > 0
 
         input_tokens = []
         input_positions = []
         block_table = []
         slot_mapping = []
         left_padded_prompt_mask = []
 
-        self.model.indices = torch.ones(len(cached_requests),
+        self.model.indices = torch.ones(len(cached_request_data.req_ids),
                                         dtype=torch.bool,
                                         device="cpu")
 
-        assert len(self.input_batch.req_id_to_index) == len(cached_requests)
+        assert len(self.input_batch.req_id_to_index) == len(
+            cached_request_data.req_ids)
 
         # TODO(wallas): I think we can do better here, without sorting or
         # creating an intermediary dictionary
-        cached_reqs_map = {c.req_id: c for c in cached_requests}
+        cached_reqs_map = {
+            req_id: i
+            for i, req_id in enumerate(cached_request_data.req_ids)
+        }
         req_ids = self.input_batch.sorted_requests_ids
         for req_id in req_ids:
             # TODO: Will this always just be one token ID if there's no spec
             # or jump decoding?
-            cached_request = cached_reqs_map[req_id]
             # adding new blocks if needed
             if self.tkv // self.block_size + 1 > len(
-                    self.req_ids2blocks[cached_request.req_id]):
-                self.req_ids2blocks[cached_request.req_id].append(
-                    self.free_blocks.popleft())
-            block_table.append(self.req_ids2blocks[cached_request.req_id])
+                    self.req_ids2blocks[req_id]):
+                self.req_ids2blocks[req_id].append(self.free_blocks.popleft())
+            block_table.append(self.req_ids2blocks[req_id])
 
             # slot_mapping for all blocks of sequence
             start_slot = block_table[-1][-1] * self.block_size
             offset = self.tkv % self.block_size
             slot = [start_slot + offset]
             slot_mapping.append(slot)
-
-            generation_token = cached_request.new_token_ids[-1]
+            new_token_ids = cached_request_data.new_token_ids[
+                cached_reqs_map[req_id]]
+            generation_token = new_token_ids[-1]
             input_tokens.append([generation_token])
-            seq_len = cached_request.num_computed_tokens
+            seq_len = cached_request_data.num_computed_tokens[
+                cached_reqs_map[req_id]]
             input_positions.append([seq_len])
-            req_state = self.requests[cached_request.req_id]
+            req_state = self.requests[req_id]
             left_padded_prompt_mask.append(req_state.left_padding)
 
         input_tokens = torch.tensor(input_tokens,
@@ -819,7 +823,7 @@ def _prepare_decode(
                                     dtype=torch.int64)
 
         # add pads for min decode batch size of 2 (Spyre compiler constraint)
-        if len(cached_requests) == 1:
+        if len(cached_request_data.req_ids) == 1:
             padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
             self.model.indices = torch.cat(
                 (self.model.indices, padd_seq_indices), -1)
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 7de7b15bd..a84596abf 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -345,7 +345,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
             for i, req in enumerate(dummy_requests):
                 scheduler_output = SchedulerOutput(
                     scheduled_new_reqs=[req],
-                    scheduled_cached_reqs=[],
+                    scheduled_cached_reqs=CachedRequestData.make_empty(),
                     num_scheduled_tokens={req.req_id: prompt_len},
                     total_num_scheduled_tokens=prompt_len,
                     scheduled_spec_decode_tokens={},
@@ -359,45 +359,50 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
                 logger.info("Warmup prefill %d/%d...", i + 1, batch_size)
                 self.execute_model(scheduler_output)
 
-            # one decode iteration across both sequences
-            cached_requests = [
-                CachedRequestData(
-                    req_id=req.req_id,
-                    resumed_from_preemption=False,
-                    new_token_ids=[
-                        valid_token_ids_tensor[torch.randint(
-                            0, len(valid_token_ids_tensor), (1, )).item()]
-                    ],  # placeholder token
-                    new_block_ids=req.block_ids,
-                    num_computed_tokens=prompt_len,
-                ) for req in dummy_requests
-            ]
-
-            scheduler_output = SchedulerOutput(
-                scheduled_new_reqs=[],
-                scheduled_cached_reqs=cached_requests,
-                num_scheduled_tokens={
-                    f"warmup-{i}": 1
-                    for i in range(batch_size)
-                },
-                total_num_scheduled_tokens=batch_size,
-                scheduled_spec_decode_tokens={},
-                scheduled_encoder_inputs={},
-                num_common_prefix_blocks=0,
-                finished_req_ids=set(),
-                free_encoder_input_ids=[],
-                structured_output_request_ids={},
-                grammar_bitmask=None,
-            )
-            logger.info("Warmup decode 1/1...")
-            self.execute_model(scheduler_output)
-            self._cleanup_model_runner(request=dummy_requests)
+            # one decode iteration across both sequences
+            req_ids = []
+            new_token_ids = []
+            new_block_ids = []
+            num_computed_tokens = []
+            for req in dummy_requests:
+                req_ids.append(req.req_id)
+                new_token_ids.append([
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ])  # placeholder token
+                new_block_ids.append([req.block_ids])
+                num_computed_tokens.append(prompt_len)
+            cached_request_data = CachedRequestData(
+                req_ids=req_ids,
+                resumed_from_preemption=False,
+                new_token_ids=new_token_ids,
+                new_block_ids=new_block_ids,
+                num_computed_tokens=num_computed_tokens,
+            )
+
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=cached_request_data,
+                num_scheduled_tokens={f"warmup-{i}": 1
+                                      for i in range(batch_size)},
+                total_num_scheduled_tokens=batch_size,
+                scheduled_spec_decode_tokens={},
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+            )
+            logger.info("Warmup decode 1/1...")
+            self.execute_model(scheduler_output)
+            self._cleanup_model_runner(request=dummy_requests)
 
         # doing one additional prefill outside the warmup_context seems to be
         # necessary to have reasonable TTFT for the first prefill after warmup
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=[add_dummy_request],
-            scheduled_cached_reqs=[],
+            scheduled_cached_reqs=CachedRequestData.make_empty(),
             num_scheduled_tokens={add_dummy_request.req_id: prompt_len},
             total_num_scheduled_tokens=prompt_len,
             scheduled_spec_decode_tokens={},
@@ -430,7 +435,7 @@ def _cleanup_model_runner(self, request) -> None:
         # Needed to clean up the data of model runner
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=[],
-            scheduled_cached_reqs=[],
+            scheduled_cached_reqs=CachedRequestData.make_empty(),
             num_scheduled_tokens={},
             # NOTE: this means no work to do
             total_num_scheduled_tokens=0,
@@ -539,23 +544,31 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         ]
 
         # Set up dummy cached_requests for decode steps
-        cached_requests = [
-            CachedRequestData(
-                req_id=req.req_id,
-                resumed_from_preemption=False,
-                new_token_ids=[
-                    valid_token_ids_tensor[torch.randint(
-                        0, len(valid_token_ids_tensor), (1, )).item()]
-                ],  # placeholder token
-                new_block_ids=req.block_ids,
-                num_computed_tokens=req.num_computed_tokens,
-            ) for req in dummy_requests
-        ]
+        req_ids = []
+        new_token_ids = []
+        new_block_ids = []
+        num_computed_tokens = []
+        for req in dummy_requests:
+            req_ids.append(req.req_id)
+            new_token_ids.append([
+                valid_token_ids_tensor[torch.randint(
+                    0, len(valid_token_ids_tensor), (1, )).item()]
+            ])  # placeholder token
+            new_block_ids.append([req.block_ids])
+            num_computed_tokens.append(req.num_computed_tokens)
+
+        cached_request_data = CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=False,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
 
         # Set up scheduler_output for execute_model
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=dummy_requests,
-            scheduled_cached_reqs=[],
+            scheduled_cached_reqs=cached_request_data,
             num_scheduled_tokens={i: prompt_len
                                   for i in range(batch_size)},
             total_num_scheduled_tokens=sum(prompt_len
@@ -574,7 +587,8 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         # The fixed size warmup needs to happen only in here
         with _maybe_warmup_context():
             self._warmup_model_forward_pass(scheduler_output, dummy_requests,
-                                            cached_requests, num_decode_tokens)
+                                            cached_request_data,
+                                            num_decode_tokens)
         self.perf_metrics.log("warmup 1 time",
                               time.time() - warmup_start_t,
                               batch_size=batch_size,
@@ -585,7 +599,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         logger.info("Warmup forward pass 2/2...")
         warmup2_start_t = time.time()
         self._warmup_model_forward_pass(scheduler_output, dummy_requests,
-                                        cached_requests, num_decode_tokens)
+                                        cached_request_data, num_decode_tokens)
 
         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
@@ -604,17 +618,17 @@ def _warmup_model_forward_pass(
         self,
         scheduler_output: SchedulerOutput,
         requests: list[NewRequestData],
-        cached_requests: list[CachedRequestData],
+        cached_request_data: CachedRequestData,
         num_decode_tokens,
     ):
         """Handle a complete forward pass"""
         scheduler_output.scheduled_new_reqs = requests
-        scheduler_output.scheduled_cached_reqs = []
+        scheduler_output.scheduled_cached_reqs = CachedRequestData.make_empty()
        self.execute_model(scheduler_output)  # Prefill
 
         # Switch to cached requests to trigger decoding steps
         scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
+        scheduler_output.scheduled_cached_reqs = cached_request_data
 
         for _ in range(num_decode_tokens - 1):
             self.execute_model(scheduler_output)
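
Note on the API this diff adapts to: upstream vLLM changed SchedulerOutput.scheduled_cached_reqs from a list[CachedRequestData] (one object per cached request) to a single CachedRequestData object holding parallel per-request lists, which is why every call site above switches from attribute access on a request object to indexing by position. A minimal sketch of that access pattern follows; the dataclass below is a simplified stand-in with an illustrative subset of fields, not vLLM's actual class.

    from dataclasses import dataclass, field


    @dataclass
    class CachedRequestDataSketch:
        # Parallel lists: index i describes the i-th cached request.
        req_ids: list[str] = field(default_factory=list)
        new_token_ids: list[list[int]] = field(default_factory=list)
        num_computed_tokens: list[int] = field(default_factory=list)

        @classmethod
        def make_empty(cls) -> "CachedRequestDataSketch":
            # Counterpart of the make_empty() used above for "no cached requests".
            return cls()


    # Old style: for req in cached_requests: use(req.req_id, req.new_token_ids[-1])
    # New style: one object, indexed by position.
    data = CachedRequestDataSketch(
        req_ids=["warmup-0", "warmup-1"],
        new_token_ids=[[11], [42]],
        num_computed_tokens=[7, 9],
    )
    for i, req_id in enumerate(data.req_ids):
        generation_token = data.new_token_ids[i][-1]
        print(req_id, generation_token, data.num_computed_tokens[i])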