32 changes: 16 additions & 16 deletions docs/features/custom_logitsprocs.md
@@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s

Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:

* `validate_params(cls, sampling_params: SamplingParams)`:
* Raise `ValueError` if `SamplingParams` contains arguments (especially custom arguments) that are invalid for the logits processor.
* When a request is sent to an entrypoint, `validate_params()` validates its `SamplingParams` and refuses the request if any arguments are invalid.
* **Note:** it is important to implement `validate_params()` so that invalid parameters cannot reach your custom logits processor; otherwise, requests with invalid parameters can cause unexpected behaviour in the processor.

* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
* `vllm_config`: engine configuration data structure
* `device`: hardware accelerator device info
@@ -38,11 +43,6 @@ Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsPr
* Use the `BatchUpdate` members to update logits processor internal state
* **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.

* `validate_params(cls, sampling_params: SamplingParams)`:
* Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor.
* When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments.
* **Note:** it's important to implent `validate_params()` to prevent invalid parameters for custom logits processor. Otherwise requests with invalid parameters can cause unexpected behaviour in custom logits processor.

### How the vLLM engine builds the `BatchUpdate` data structure

!!! important
@@ -108,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")

def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
self.req_info: dict[int, int] = {}
Expand Down Expand Up @@ -164,14 +172,6 @@ The contrived example below implements a custom logits processor which consumes

return logits

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")

```

In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
@@ -241,9 +241,6 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""

def is_argmax_invariant(self) -> bool:
return False

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
@@ -254,6 +251,9 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
f"target_token value {target_token} is not int"
)

def is_argmax_invariant(self) -> bool:
return False

def new_req_logits_processor(
self,
params: SamplingParams,
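To make the validation flow concrete, here is a minimal sketch of `validate_params()` in action, assuming the `DummyLogitsProcessor` class from the docs diff above is in scope. The entrypoint performs an equivalent check when a request arrives, refusing it before `apply()` ever runs:

```python
from vllm import SamplingParams

good = SamplingParams(extra_args={"target_token": 128})
bad = SamplingParams(extra_args={"target_token": "128"})  # str, not int

DummyLogitsProcessor.validate_params(good)  # returns silently

try:
    DummyLogitsProcessor.validate_params(bad)
except ValueError as e:
    # The entrypoint surfaces this back to the client as a refused request.
    print(f"request refused: {e}")
```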
22 changes: 12 additions & 10 deletions examples/offline_inference/logits_processor/custom.py
@@ -33,6 +33,8 @@ class object.
------------------------------------------------------------
"""

from typing import Any

import torch

from vllm import LLM, SamplingParams
@@ -48,6 +50,16 @@ class object.
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
@@ -89,16 +101,6 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:

return logits

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)


# Sample prompts.
prompts = [
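For context, a hedged sketch of how such a processor can be wired into an offline `LLM` instance, reusing the `LLM`, `SamplingParams`, and `DummyLogitsProcessor` names from the file above; the model choice here is illustrative, not part of the diff:

```python
# Illustrative wiring only; assumes the imports shown in custom.py above.
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=[DummyLogitsProcessor],
)

# A request whose extra_args pass validate_params(); an invalid value
# (e.g. a string) would be refused before generation starts.
params = SamplingParams(temperature=0.0, extra_args={"target_token": 128})
outputs = llm.generate(["Hello, my name is"], params)
```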
6 changes: 3 additions & 3 deletions examples/offline_inference/logits_processor/custom_req.py
@@ -76,9 +76,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""

def is_argmax_invariant(self) -> bool:
return False

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
@@ -87,6 +84,9 @@ def validate_params(cls, params: SamplingParams):
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")

def is_argmax_invariant(self) -> bool:
return False

def new_req_logits_processor(
self,
params: SamplingParams,
16 changes: 8 additions & 8 deletions examples/offline_inference/logits_processor/custom_req_init.py
@@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
@@ -86,14 +94,6 @@ def __init__(
def is_argmax_invariant(self) -> bool:
return False

@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)

def new_req_logits_processor(
self,
params: SamplingParams,
20 changes: 10 additions & 10 deletions tests/v1/logits_processors/utils.py
@@ -52,6 +52,16 @@ class CustomLogitprocSource(Enum):
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)

def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
@@ -91,16 +101,6 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:

return logits

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)


"""Dummy module with dummy logitproc class"""
dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE)
16 changes: 8 additions & 8 deletions vllm/model_executor/models/deepseek_ocr.py
@@ -131,14 +131,6 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
super().__init__(vllm_config, device, is_pin_memory)

def is_argmax_invariant(self) -> bool:
return True

@classmethod
def validate_params(cls, params: SamplingParams):
ngram_size = params.extra_args and params.extra_args.get("ngram_size")
@@ -168,6 +160,14 @@ def validate_params(cls, params: SamplingParams):
f"got {whitelist_token_ids}."
)

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
super().__init__(vllm_config, device, is_pin_memory)

def is_argmax_invariant(self) -> bool:
return True

def new_req_logits_processor(
self,
params: SamplingParams,
9 changes: 5 additions & 4 deletions vllm/utils/torch_utils.py
@@ -89,10 +89,11 @@ def guard_cuda_initialization():
else:
err_msg = str(e)
raise RuntimeError(err_msg) from e
if had_key:
os.environ["CUDA_VISIBLE_DEVICES"] = old_value
else:
os.environ.pop("CUDA_VISIBLE_DEVICES")
finally:
if had_key:
os.environ["CUDA_VISIBLE_DEVICES"] = old_value
else:
os.environ.pop("CUDA_VISIBLE_DEVICES")


def get_dtype_size(dtype: torch.dtype) -> int:
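The `vllm/utils/torch_utils.py` hunk moves restoration of `CUDA_VISIBLE_DEVICES` into a `finally` block, so the variable is restored even when the guarded code raises; previously the restore was skipped whenever an exception propagated. A self-contained sketch of the same save/restore pattern (illustrative names, not vLLM's API):

```python
import os
from contextlib import contextmanager

@contextmanager
def temporary_env(key: str, value: str):
    had_key = key in os.environ
    old_value = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        # Runs on success and on exception alike; restoration placed after
        # an except handler (as before this fix) is skipped on re-raise.
        if had_key:
            os.environ[key] = old_value
        else:
            os.environ.pop(key)
```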
16 changes: 8 additions & 8 deletions vllm/v1/sample/logits_processor/interface.py
@@ -58,6 +58,14 @@ class BatchUpdate:


class LogitsProcessor(ABC):
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.

Raise ValueError for invalid ones.
"""
return None

@abstractmethod
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
@@ -96,11 +104,3 @@ def update_state(
to the batch makeup.
"""
raise NotImplementedError

@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.

Raise ValueError for invalid ones.
"""
return None
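Taken together, the diffs establish one consistent ordering: `validate_params()` is declared first, ahead of `__init__()` and the other methods. A skeletal subclass in that order might look like the sketch below. It assumes `LogitsProcessor` and `BatchUpdate` are importable from `vllm.v1.sample.logits_processor` as the docs page suggests, and the no-op bodies are placeholders:

```python
import torch

from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor


class NoOpLogitsProcessor(LogitsProcessor):
    """Skeleton following the method order the diffs above establish."""

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        # No custom arguments to check; identical to the base-class default.
        return None

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        pass

    def is_argmax_invariant(self) -> bool:
        # Leaving logits untouched never changes the argmax token.
        return True

    def update_state(self, batch_update: "BatchUpdate | None"):
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits
```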