Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions vllm/v1/sample/logits_processor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING

import torch
Expand Down Expand Up @@ -216,11 +216,17 @@ def build_logitsprocs(
)


cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While using lru_cache on _load_custom_logitsprocs is a good performance improvement, it can be made more efficient. _load_custom_logitsprocs internally calls _load_logitsprocs_plugins(), which does not depend on any arguments. With the current implementation, _load_logitsprocs_plugins() will be re-executed for every cache miss of cached_load_custom_logitsprocs (i.e., for each new logits_processors value).

To avoid this repeated work, _load_logitsprocs_plugins() should be cached independently. The ideal solution would be to apply @lru_cache directly to _load_logitsprocs_plugins and _load_logitsprocs_by_fqcns. This would require modifying those functions, which are outside the current diff.

For example:

@lru_cache(maxsize=None)
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    # ... function body

@lru_cache(maxsize=None)
def _load_logitsprocs_by_fqcns(
    logits_processors: tuple[str | type[LogitsProcessor], ...] | None,
) -> list[type[LogitsProcessor]]:
    # ... function body

Then _load_custom_logitsprocs can call these cached functions, and validate_logits_processors_parameters can call _load_custom_logitsprocs directly without an extra caching layer. This would be the most efficient implementation.



def validate_logits_processors_parameters(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
sampling_params: SamplingParams,
):
for logits_procs in _load_custom_logitsprocs(logits_processors):
logits_processors = (
tuple(logits_processors) if logits_processors is not None else None
)
for logits_procs in cached_load_custom_logitsprocs(logits_processors):
logits_procs.validate_params(sampling_params)


Expand Down