diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 0f2e0d2e31..83213ba316 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
 import torch
@@ -15,6 +15,7 @@
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams, SamplingType
@@ -51,11 +52,10 @@ class TpuPlatform(Platform):
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
-                             head_size: int, dtype: jnp.dtype,
-                             kv_cache_dtype: Optional[str], block_size: int,
-                             use_mla: bool, has_sink: bool, use_sparse: bool,
-                             use_mm_prefix: bool, attn_type: Any) -> str:
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
+
         if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
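
For reviewers, a minimal sketch of the new calling convention. The diff only imports AttentionSelectorConfig under TYPE_CHECKING and does not show its constructor, so the stand-in dataclass and field names below are assumptions mirroring the positional parameters removed from the old signature, not the library's confirmed API.

    from dataclasses import dataclass
    from typing import Optional

    import jax.numpy as jnp


    # Stand-in for vllm.attention.selector.AttentionSelectorConfig; the field
    # names are assumptions based on the parameters removed from the old
    # signature of get_attn_backend_cls.
    @dataclass
    class AttentionSelectorConfig:
        head_size: int
        dtype: jnp.dtype
        kv_cache_dtype: Optional[str]
        block_size: int
        use_mla: bool = False
        has_sink: bool = False
        use_sparse: bool = False
        use_mm_prefix: bool = False


    # Old call: every attention property passed positionally.
    # New call: one config object, with any extra selector fields forwarded
    # through **kwargs so platform overrides stay stable as fields are added.
    cfg = AttentionSelectorConfig(head_size=128,
                                  dtype=jnp.bfloat16,
                                  kv_cache_dtype=None,
                                  block_size=16)

    from vllm.attention.backends.registry import AttentionBackendEnum
    from tpu_inference.platforms.tpu_platform import TpuPlatform

    backend_cls = TpuPlatform.get_attn_backend_cls(AttentionBackendEnum.PALLAS, cfg)

Collapsing the long positional list into a single config object plus **kwargs means future selector fields should not require touching every platform's override of this method.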