diff --git a/vllm_hpu/v1/worker/hpu_model_runner.py b/vllm_hpu/v1/worker/hpu_model_runner.py
index 1300eb2c3..fe66b1df7 100644
--- a/vllm_hpu/v1/worker/hpu_model_runner.py
+++ b/vllm_hpu/v1/worker/hpu_model_runner.py
@@ -44,6 +44,7 @@
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
 from vllm_hpu.v1.worker.hpu_input_batch import InputBatch
+from vllm.distributed.parallel_state import get_pp_group
 
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
@@ -768,26 +769,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
+        is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
-            new_token_ids = req_data.new_token_ids[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
             req_state.num_computed_tokens = num_computed_tokens
-            # Update the cached states.
-            # Add the sampled token(s) from the previous step (if any).
-            # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                              req_state.num_tokens)
-            if num_new_tokens == 1:
-                # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(new_token_ids[-1])
-            elif num_new_tokens > 0:
-                req_state.output_token_ids.extend(
-                    new_token_ids[-num_new_tokens:])
+            if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
 
             # Update the block IDs.
             if not resumed_from_preemption:
@@ -809,22 +814,26 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
             self.input_batch.block_table.append_row(new_block_ids, req_index)
-            # Add new_token_ids to token_ids_cpu.
-            start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(new_token_ids)
-            self.input_batch.token_ids_cpu[
-                req_index, start_token_index:end_token_index] = new_token_ids
-            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                start_index = end_token_index
-                end_token_index += len(spec_token_ids)
+
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not is_last_rank:
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
                 self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-            self.input_batch.num_tokens[req_index] = end_token_index
+                    req_index, start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[req_index] = end_token_index
+                # Add spec_token_ids to token_ids_cpu.
+                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                    req_id, ())
+                if spec_token_ids:
+                    start_index = end_token_index
+                    end_token_index += len(spec_token_ids)
+                    self.input_batch.token_ids_cpu[
+                        req_index, start_index:end_token_index] = spec_token_ids
+                # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+                self.input_batch.num_tokens[req_index] = end_token_index
 
         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
@@ -1654,6 +1663,30 @@ def execute_model(
             self.input_batch.num_tokens[i] += len(token_ids)
             req_state.output_token_ids.extend(token_ids)
 
+        # NOTE(chendi): enable cache based on PR(#20291)
+        # Cache the sampled tokens in the model runner, so that the scheduler
+        # doesn't need to send them back.
+        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+        # the sampled tokens back, because there's no direct communication
+        # between the first-stage worker and the last-stage worker.
+        for req_idx, sampled_ids in enumerate(postprocessed_sampled_token_ids[:num_reqs]):
+            if not sampled_ids:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            end_idx = start_idx + len(sampled_ids)
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+            req_id = self.input_batch.req_ids[req_idx]
+            req_state = self.requests[req_id]
+            req_state.output_token_ids.extend(sampled_ids)
         ################## RETURN ##################
         # Create output.
         all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
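For reference, a minimal, hypothetical sketch of the sampled-token caching pattern this change adopts from PR #20291: the runner writes the step's sampled ids into its own CPU buffers and advances the per-request counters, so the scheduler never has to echo the tokens back (except under pipeline parallelism, which the patch handles separately). NumPy arrays stand in for the real InputBatch buffers; all request counts and token values below are illustrative only, not the actual vllm_hpu API.

    import numpy as np

    # Stand-ins for InputBatch.token_ids_cpu / num_tokens_no_spec / num_tokens.
    max_model_len = 32
    token_ids_cpu = np.zeros((2, max_model_len), dtype=np.int64)
    num_tokens_no_spec = np.array([5, 7])   # tokens cached so far, excluding spec tokens
    num_tokens = num_tokens_no_spec.copy()
    output_token_ids = [[], []]

    # Tokens sampled on this step; a request may produce more than one (spec decode).
    sampled_token_ids = [[101], [202, 203]]

    for req_idx, sampled_ids in enumerate(sampled_token_ids):
        if not sampled_ids:
            continue
        start_idx = num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= max_model_len, "sampled tokens exceed max_model_len"
        # Cache locally instead of waiting for the scheduler to send the
        # tokens back on the next step.
        token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
        num_tokens_no_spec[req_idx] = end_idx
        num_tokens[req_idx] = end_idx
        output_token_ids[req_idx].extend(sampled_ids)

    print(num_tokens_no_spec)  # [6 9]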