diff --git a/.buildkite/vllm_lkg.version b/.buildkite/vllm_lkg.version index 86b654098b..93bc509737 100644 --- a/.buildkite/vllm_lkg.version +++ b/.buildkite/vllm_lkg.version @@ -1 +1 @@ -8c0b6267d7fa5c8a07e318809180fc021a0afbf2 +d9408ffba3c8da5e289b5695d507707afce10a2f diff --git a/tests/runner/test_input_batch.py b/tests/runner/test_input_batch.py index 4a218cdb9c..8fe0cbdf0b 100644 --- a/tests/runner/test_input_batch.py +++ b/tests/runner/test_input_batch.py @@ -261,6 +261,7 @@ def meta_eq(a: PoolingMetadata, b: PoolingMetadata): PoolingMetadata( prompt_lens=torch.tensor([], dtype=torch.int32), prompt_token_ids=None, + prompt_token_ids_cpu=None, pooling_params=[], pooling_states=[], ), @@ -281,6 +282,7 @@ def meta_eq(a: PoolingMetadata, b: PoolingMetadata): PoolingMetadata( prompt_lens=torch.tensor([10], dtype=torch.int32), prompt_token_ids=None, + prompt_token_ids_cpu=None, pooling_params=[pooling_param], pooling_states=[pooling_state], ), @@ -292,6 +294,7 @@ def meta_eq(a: PoolingMetadata, b: PoolingMetadata): PoolingMetadata( prompt_lens=torch.tensor([], dtype=torch.int32), prompt_token_ids=None, + prompt_token_ids_cpu=None, pooling_params=[], pooling_states=[], ), diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index 3042bf90b6..4dbea3495e 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -32,6 +32,7 @@ from vllm.v1.executor.ray_distributed_executor import \ RayDistributedExecutor as RayDistributedExecutorV1 from vllm.v1.executor.ray_executor import RayWorkerMetaData +from vllm.v1.executor.ray_utils import WORKER_SPECIFIC_ENV_VARS from vllm.v1.executor.ray_utils import RayWorkerWrapper as RayWorkerWrapperV1 from vllm.v1.executor.ray_utils import _wait_until_pg_ready from vllm.v1.outputs import ModelRunnerOutput @@ -363,7 +364,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # Environment variables to copy from driver to workers env_vars_to_copy = get_env_vars_to_copy( - exclude_vars=self.WORKER_SPECIFIC_ENV_VARS, + exclude_vars=WORKER_SPECIFIC_ENV_VARS, additional_vars=set(current_platform.additional_env_vars), destination="workers") @@ -483,7 +484,7 @@ class RayWorkerWrapper(RayWorkerWrapperV1): Ray worker wrapper for TPU. The implementation is similar to vllm/v1/executor/ray_utils.py - + _is_intermediate_tensors: check whether the output is JaxIntermediateTensors. _is_last_rank: check whether this Ray worker is the last PP stage. """ diff --git a/tpu_inference/runner/input_batch.py b/tpu_inference/runner/input_batch.py index 3e658c0d2b..392328bcb9 100644 --- a/tpu_inference/runner/input_batch.py +++ b/tpu_inference/runner/input_batch.py @@ -163,6 +163,7 @@ def get_pooling_metadata(self) -> PoolingMetadata: prompt_lens=torch.from_numpy( self.num_prompt_tokens[:self.num_reqs]), prompt_token_ids=None, + prompt_token_ids_cpu=None, pooling_params=pooling_params, pooling_states=pooling_states, )