diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 09522b5bcd..9f2a785264 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -8,7 +8,6 @@ from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 
@@ -19,12 +18,15 @@
     from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
     AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None
 
 logger = init_logger(__name__)
 
@@ -250,10 +252,11 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
     @classmethod
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
         processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType
         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED: