Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tpu_inference/platforms/tpu_platform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

import jax.numpy as jnp
import torch
Expand All @@ -15,6 +15,7 @@

if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
Expand Down Expand Up @@ -51,11 +52,10 @@ class TpuPlatform(Platform):

@classmethod
def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
head_size: int, dtype: jnp.dtype,
kv_cache_dtype: Optional[str], block_size: int,
use_mla: bool, has_sink: bool, use_sparse: bool,
use_mm_prefix: bool, attn_type: Any) -> str:
attn_selector_config: "AttentionSelectorConfig",
**kwargs) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum

if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)

Expand Down