Skip to content
This repository was archived by the owner on Mar 21, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}

/// Definition of a custom logits processor to apply during generation
message LogitsProcessorParameters {
/// The name of the processor to apply (must match a registered factory)
string name = 1;
/// The parameters to pass to the processor, as uninterpreted strings
repeated string parameters = 2;
}

message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
Expand All @@ -67,6 +74,8 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// Optional Logits Processors definitions
repeated LogitsProcessorParameters logits_processors = 9;
}

message StoppingCriteriaParameters {
Expand Down
1 change: 1 addition & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
watermark: true,
logits_processors: vec![],
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
Expand Down
1 change: 1 addition & 0 deletions router/src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl Health {
seed: 0,
repetition_penalty: 1.0,
watermark: false,
logits_processors: vec![],
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
Expand Down
1 change: 1 addition & 0 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ impl Validation {
do_sample,
seed,
watermark,
logits_processors: vec![],
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
Expand Down
11 changes: 10 additions & 1 deletion server/text_generation_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pathlib import Path
from loguru import logger
from typing import Optional
from typing import List, Optional
from enum import Enum
from huggingface_hub import hf_hub_download

Expand Down Expand Up @@ -38,6 +38,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
custom_modules: Optional[List[str]] = None,
):
if sharded:
assert (
Expand Down Expand Up @@ -65,6 +66,14 @@ def serve(
diagnose=False,
)

# Import custom modules. This can be used for custom logits processors:
# such modules call CustomLogitsProcessorsManager.register_factory() in their
# __init__.py, registering themselves for custom logits processing.
from importlib import import_module
if custom_modules:
for custom_module in custom_modules:
import_module(custom_module)

# Import here after the logger is added to log potential import exceptions
from text_generation_server import server
from text_generation_server.tracing import setup_tracing
Expand Down
2 changes: 1 addition & 1 deletion server/text_generation_server/models/galactica.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def from_pb(
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
Expand Down
29 changes: 29 additions & 0 deletions server/text_generation_server/utils/custom_logits_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Dict, Iterable
from transformers import PreTrainedTokenizerBase, LogitsWarper
import abc


class CustomLogitsWarperFactory(abc.ABC):
    """Abstract base for factories that build custom logits warpers.

    A subclass implements :meth:`create_warper` and registers an instance of
    itself (keyed by ``self.name``) via
    ``CustomLogitsProcessorsManager.register_factory``.
    """

    def __init__(self, name):
        # Registry key under which this factory is looked up.
        self.name = name

    @abc.abstractmethod
    def create_warper(
        self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]
    ) -> LogitsWarper:
        """Build a ``LogitsWarper`` from request-supplied string parameters."""
        raise NotImplementedError


class CustomLogitsProcessorsManager:
    """Process-wide registry mapping warper names to their factories."""

    # Shared class-level registry; populated at import time by custom modules.
    processors: Dict[str, CustomLogitsWarperFactory] = {}

    @classmethod
    def register_factory(cls, factory: CustomLogitsWarperFactory):
        """Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module,
        and the module should be included in the custom_modules command line argument of the text-generation-server CLI."""
        cls.processors[factory.name] = factory

    @classmethod
    def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
        """Create a custom warper by name."""
        factory = cls.processors.get(name)
        if factory is None:
            raise ValueError(
                f"Unknown warper {name}. Known warpers: {', '.join(cls.processors.keys())}"
            )
        return factory.create_warper(tokenizer, params)
24 changes: 21 additions & 3 deletions server/text_generation_server/utils/tokens.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, NamedTuple, Optional, Tuple
from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager

import torch
from text_generation_server.pb import generate_pb2
Expand All @@ -16,6 +17,9 @@
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

class LogitsProcessorParams(NamedTuple):
    """Pairing of a custom logits processor name with its raw string parameters."""

    # Registry key of the processor to apply.
    name: str
    # Uninterpreted parameter strings forwarded to the processor factory.
    params: List[str]

class NextTokenChooser:
def __init__(
Expand All @@ -29,6 +33,8 @@ def __init__(
do_sample=False,
seed=0,
device="cpu",
logits_processors_params: Optional[List[LogitsProcessorParams]] = None,
tokenizer: PreTrainedTokenizerBase=None,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
Expand All @@ -51,7 +57,10 @@ def __init__(
)
else:
self.static_warper = None

if logits_processors_params:
self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, params, tokenizer) for name, params in logits_processors_params]
else:
self.custom_warpers = None
sampling = do_sample or has_warpers
self.choice = Sampling(seed, device) if sampling else Greedy()

Expand All @@ -60,7 +69,9 @@ def __call__(self, input_ids, scores):
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)

if self.custom_warpers is not None:
for warper in self.custom_warpers:
scores = warper(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
else:
Expand All @@ -75,7 +86,12 @@ def from_pb(
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
) -> "NextTokenChooser":
if pb.logits_processors:
processors_params = [LogitsProcessorParams(name, params) for name, params in pb.logits_processors]
else:
processors_params = None
return NextTokenChooser(
watermark=pb.watermark,
temperature=pb.temperature,
Expand All @@ -86,6 +102,8 @@ def from_pb(
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
logits_processors_params=processors_params,
tokenizer=tokenizer,
)


Expand Down