Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions lib/bindings/python/src/dynamo/logits_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Dynamo Logits Processing - Backend-agnostic logits processors.

This module provides the BaseLogitsProcessor protocol that can be used
across different backend adapters (TRT-LLM, vLLM, SGLang).
"""

# Re-export the protocol from the sibling `base` module at package level.
from .base import BaseLogitsProcessor

# Names exported by a star-import of this package.
__all__ = ["BaseLogitsProcessor"]
39 changes: 39 additions & 0 deletions lib/bindings/python/src/dynamo/logits_processing/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Base logits processor protocol for Dynamo.

This module defines the core BaseLogitsProcessor interface that all
logits processors must implement.
"""

from typing import List, Protocol

import torch


class BaseLogitsProcessor(Protocol):
    """
    Structural interface for Dynamo logits processors.

    A logits processor is any callable object whose ``__call__`` matches
    the signature declared here. Backend adapters (TRT-LLM, vLLM, SGLang)
    rely on this shape, so every processor has to provide it.
    """

    def __call__(self, input_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """
        Transform the next-token logits.

        Args:
            input_ids: Token IDs generated so far.
            logits: Raw next-token logits. Shape: (vocab_size,)

        Returns:
            The adjusted logits tensor, with the same shape as ``logits``.

        Raises:
            NotImplementedError: Always, when this protocol stub itself is
                invoked; concrete processors supply the real behaviour.
        """
        raise NotImplementedError
Loading