Cache the loading of custom logits processors during parameter validation.

Wraps _load_custom_logitsprocs in functools.lru_cache and converts the
incoming sequence to a tuple before the call, because lru_cache requires
hashable arguments.

diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py
index eb537eae6c90..5992c4066c9c 100644
--- a/vllm/v1/sample/logits_processor/__init__.py
+++ b/vllm/v1/sample/logits_processor/__init__.py
@@ -5,7 +5,7 @@
 import itertools
 from abc import abstractmethod
 from collections.abc import Sequence
-from functools import partial
+from functools import lru_cache, partial
 from typing import TYPE_CHECKING
 
 import torch
@@ -216,11 +216,17 @@ def build_logitsprocs(
     )
 
 
+cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
+
+
 def validate_logits_processors_parameters(
     logits_processors: Sequence[str | type[LogitsProcessor]] | None,
     sampling_params: SamplingParams,
 ):
-    for logits_procs in _load_custom_logitsprocs(logits_processors):
+    logits_processors = (
+        tuple(logits_processors) if logits_processors is not None else None
+    )
+    for logits_procs in cached_load_custom_logitsprocs(logits_processors):
         logits_procs.validate_params(sampling_params)
 
 