diff --git a/proto/generate.proto b/proto/generate.proto index c873e6615ef..7dbbda4a326 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -50,6 +50,13 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} +message LogitsProcessorParameters { + // The name of the processor to apply + string name = 1; + // The parameters to pass to the processor + repeated string parameters = 2; +} + message NextTokenChooserParameters { /// exponential scaling output probability distribution float temperature = 1; @@ -67,6 +74,8 @@ message NextTokenChooserParameters { float repetition_penalty = 7; /// token watermarking using "A Watermark for Large Language Models" bool watermark = 8; + /// Optional Logits Processors definitions + repeated LogitsProcessorParameters logits_processors = 9; } message StoppingCriteriaParameters { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 341e70fd588..22f3f6632b3 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -125,6 +125,7 @@ impl Client { seed: 0, repetition_penalty: 1.2, watermark: true, + logits_processors: vec![], }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: max_total_tokens - truncate, diff --git a/router/src/health.rs b/router/src/health.rs index ab290fc16e0..3cf6861c1a3 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -44,6 +44,7 @@ impl Health { seed: 0, repetition_penalty: 1.0, watermark: false, + logits_processors: vec![], }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/router/src/validation.rs b/router/src/validation.rs index 1b47fc9749c..a021e569f0b 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -279,6 +279,7 @@ impl Validation { do_sample, seed, watermark, + logits_processors: vec![], }; let stopping_parameters = StoppingCriteriaParameters { max_new_tokens, diff --git a/server/text_generation_server/cli.py 
b/server/text_generation_server/cli.py index 301acb6be02..c668b138c42 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -4,7 +4,7 @@ from pathlib import Path from loguru import logger -from typing import Optional +from typing import List, Optional from enum import Enum from huggingface_hub import hf_hub_download @@ -38,6 +38,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + custom_modules: Optional[List[str]] = None, ): if sharded: assert ( @@ -65,6 +66,14 @@ def serve( diagnose=False, ) + # Import custom modules. This can be used for Custom Logits Processors, + # in which these modules call CustomLogitsProcessorsManager.register_factory() in their __init__.py, + # registering themselves for custom logits processing. + from importlib import import_module + if custom_modules: + for custom_module in custom_modules: + import_module(custom_module) + # Import here after the logger is added to log potential import exceptions from text_generation_server import server from text_generation_server.tracing import setup_tracing diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index b296c96ebf9..65f25728f78 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -92,7 +92,7 @@ def from_pb( requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) diff --git a/server/text_generation_server/utils/custom_logits_processors.py b/server/text_generation_server/utils/custom_logits_processors.py new file mode 100644 index 
00000000000..446803049a8 --- /dev/null +++ b/server/text_generation_server/utils/custom_logits_processors.py @@ -0,0 +1,29 @@ +from typing import Dict, Iterable +from transformers import PreTrainedTokenizerBase, LogitsWarper +import abc + + +class CustomLogitsWarperFactory(abc.ABC): + def __init__(self, name): + self.name = name + + @abc.abstractmethod + def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper: + raise NotImplementedError + + +class CustomLogitsProcessorsManager: + processors: Dict[str, CustomLogitsWarperFactory] = {} + + @classmethod + def register_factory(cls, factory: CustomLogitsWarperFactory): + """Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module, + and the module should be included in the custom_modules command line argument of the text-generation-server CLI.""" + cls.processors[factory.name] = factory + + @classmethod + def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper: + """Create a custom warper by name.""" + if name not in cls.processors: + raise ValueError(f"Unknown warper {name}. 
Known warpers: {', '.join(cls.processors.keys())}") +        return cls.processors[name].create_warper(tokenizer, params) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 0ff07417145..bd368b8c000 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,5 +1,6 @@ import re -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, NamedTuple, Optional, Tuple +from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager import torch from text_generation_server.pb import generate_pb2 @@ -16,6 +17,9 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor +class LogitsProcessorParams(NamedTuple): + name: str + params: List[str] class NextTokenChooser: def __init__( @@ -29,6 +33,8 @@ def __init__( do_sample=False, seed=0, device="cpu", + logits_processors_params: Optional[List[LogitsProcessorParams]] = None, + tokenizer: PreTrainedTokenizerBase=None, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -51,7 +57,10 @@ def __init__( ) else: self.static_warper = None - + if logits_processors_params: + self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, tokenizer, params) for name, params in logits_processors_params] + else: + self.custom_warpers = None sampling = do_sample or has_warpers self.choice = Sampling(seed, device) if sampling else Greedy() @@ -60,7 +69,9 @@ def __call__(self, input_ids, scores): scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) - + if self.custom_warpers is not None: + for warper in self.custom_warpers: + scores = warper(input_ids, scores) if self.static_warper is None: next_logprob = 
torch.log_softmax(scores, -1) else: @@ -75,7 +86,12 @@ def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "NextTokenChooser": + if pb.logits_processors: + processors_params = [LogitsProcessorParams(p.name, list(p.parameters)) for p in pb.logits_processors] + else: + processors_params = None return NextTokenChooser( watermark=pb.watermark, temperature=pb.temperature, @@ -86,6 +102,8 @@ def from_pb( do_sample=pb.do_sample, seed=pb.seed, device=device, + logits_processors_params=processors_params, + tokenizer=tokenizer, )