Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/vllm_lkg.version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
8c0b6267d7fa5c8a07e318809180fc021a0afbf2
d9408ffba3c8da5e289b5695d507707afce10a2f
3 changes: 3 additions & 0 deletions tests/runner/test_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def meta_eq(a: PoolingMetadata, b: PoolingMetadata):
PoolingMetadata(
prompt_lens=torch.tensor([], dtype=torch.int32),
prompt_token_ids=None,
prompt_token_ids_cpu=None,
pooling_params=[],
pooling_states=[],
),
Expand All @@ -281,6 +282,7 @@ def meta_eq(a: PoolingMetadata, b: PoolingMetadata):
PoolingMetadata(
prompt_lens=torch.tensor([10], dtype=torch.int32),
prompt_token_ids=None,
prompt_token_ids_cpu=None,
pooling_params=[pooling_param],
pooling_states=[pooling_state],
),
Expand All @@ -292,6 +294,7 @@ def meta_eq(a: PoolingMetadata, b: PoolingMetadata):
PoolingMetadata(
prompt_lens=torch.tensor([], dtype=torch.int32),
prompt_token_ids=None,
prompt_token_ids_cpu=None,
pooling_params=[],
pooling_states=[],
),
Expand Down
5 changes: 3 additions & 2 deletions tpu_inference/executors/ray_distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from vllm.v1.executor.ray_distributed_executor import \
RayDistributedExecutor as RayDistributedExecutorV1
from vllm.v1.executor.ray_executor import RayWorkerMetaData
from vllm.v1.executor.ray_utils import WORKER_SPECIFIC_ENV_VARS
from vllm.v1.executor.ray_utils import RayWorkerWrapper as RayWorkerWrapperV1
from vllm.v1.executor.ray_utils import _wait_until_pg_ready
from vllm.v1.outputs import ModelRunnerOutput
Expand Down Expand Up @@ -363,7 +364,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):

# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
exclude_vars=WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars),
destination="workers")

Expand Down Expand Up @@ -483,7 +484,7 @@ class RayWorkerWrapper(RayWorkerWrapperV1):
Ray worker wrapper for TPU.

The implementation is similar to vllm/v1/executor/ray_utils.py

_is_intermediate_tensors: check whether the output is JaxIntermediateTensors.
_is_last_rank: check whether this Ray worker is the last PP stage.
"""
Expand Down
1 change: 1 addition & 0 deletions tpu_inference/runner/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def get_pooling_metadata(self) -> PoolingMetadata:
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=None,
prompt_token_ids_cpu=None,
pooling_params=pooling_params,
pooling_states=pooling_states,
)
Expand Down
Loading