diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 12e4215038d74..dd122f9f1272b 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -543,7 +543,8 @@ def execute_model( seq_group_metadata_list=ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=i == 0) model_input.async_callback() if use_async_out_proc: return [sampler_outputs[-1]]