Skip to content

Commit b92a805

Browse files
authored
feat: add BaseLogitsProcessor core interface (#2613)
Signed-off-by: Bhuvan Agrawal <[email protected]>
1 parent 0a71aea commit b92a805

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
Dynamo Logits Processing - Backend-agnostic logits processors.
6+
7+
This module provides the BaseLogitsProcessor protocol that can be used
8+
across different backend adapters (TRT-LLM, vLLM, SGLang).
9+
"""
10+
11+
from .base import BaseLogitsProcessor
12+
13+
__all__ = ["BaseLogitsProcessor"]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
Base logits processor protocol for Dynamo.
6+
7+
This module defines the core BaseLogitsProcessor interface that all
8+
logits processors must implement.
9+
"""
10+
11+
from typing import Protocol, Sequence
12+
13+
import torch
14+
15+
16+
class BaseLogitsProcessor(Protocol):
17+
"""
18+
Protocol for logits processors in Dynamo.
19+
20+
All logits processors must implement this interface to be compatible
21+
with backend adapters (TRT-LLM, vLLM, SGLang).
22+
"""
23+
24+
def __call__(
25+
self,
26+
input_ids: Sequence[int],
27+
logits: torch.Tensor,
28+
) -> torch.Tensor:
29+
"""
30+
Process the logits for the next token prediction.
31+
32+
Args:
33+
input_ids: The input token IDs generated so far.
34+
logits: The raw logits for the next token. Shape: (vocab_size,)
35+
36+
Returns:
37+
A tensor with the same shape, dtype, and device as `logits`.
38+
"""
39+
...

0 commit comments

Comments
 (0)