python/pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ runtime_common = [
"hf_transfer",
"huggingface_hub",
"interegular",
"llguidance>=0.6.15",
"llguidance>=0.7.11,<0.8.0",
"modelscope",
"ninja",
"orjson",
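The new constraint raises the floor to 0.7.11, the 0.7.x line that provides the `LLMatcher`/`grammar_from` API used in the backend below, and caps at 0.8.0 to guard against breaking changes. A minimal sketch for checking that an installed wheel satisfies the pin (standard library only; naive integer parsing that would reject pre-release suffixes):

```python
from importlib.metadata import version

# Parse "major.minor.patch" into a comparable tuple; assumes a plain
# numeric version string (a "0.7.11rc1"-style suffix would need packaging).
v = tuple(int(p) for p in version("llguidance").split(".")[:3])
assert (0, 7, 11) <= v < (0, 8, 0), f"llguidance {v} is outside the pinned range"
```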
python/sglang/srt/constrained/llguidance_backend.py (139 changes: 78 additions & 61 deletions)
@@ -14,49 +14,48 @@
"""Constrained decoding with llguidance backend."""

import json
import logging
import os
from typing import List, Optional, Tuple

import llguidance
import llguidance.hf
import llguidance.torch
import torch
from llguidance.gbnf_to_lark import any_to_lark
from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
from llguidance.hf import from_tokenizer
from llguidance.torch import (
allocate_token_bitmask,
apply_token_bitmask_inplace,
fill_next_token_bitmask,
)

from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)

logger = logging.getLogger(__name__)


class GuidanceGrammar(BaseGrammarObject):
def __init__(
self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
):

def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
super().__init__()
self.llguidance_tokenizer = llguidance_tokenizer
self.serialized_grammar = serialized_grammar

# TODO: add support for fast-forward tokens in the future
self.ll_interpreter = llguidance.LLInterpreter(
self.ll_matcher = LLMatcher(
self.llguidance_tokenizer,
self.serialized_grammar,
enable_backtrack=False,
enable_ff_tokens=False,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.pending_ff_tokens: list[int] = []
self.finished = False
self.bitmask = None

def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
if len(self.pending_ff_tokens) > 0:
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
ff_tokens = self.pending_ff_tokens
self.pending_ff_tokens = []
return (ff_tokens, s)

return None
ff_tokens = self.ll_matcher.compute_ff_tokens()
if ff_tokens:
return ff_tokens, ""
else:
return None

def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1
@@ -67,32 +66,22 @@ def jump_and_retokenize(
         pass

     def accept_token(self, token: int):
-        backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
-        if len(ff_tokens) > 0 and backtrack == 0:
-            # first token is last generated token
-            ff_tokens = ff_tokens[1:]
-        self.pending_ff_tokens.extend(ff_tokens)
+        if not self.ll_matcher.consume_token(token):
+            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
+            self.finished = True

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        if len(self.pending_ff_tokens) > 0:
-            # if we have pending fast-forward tokens,
-            # just return them immediately
-            ff_token = self.pending_ff_tokens.pop(0)
-            vocab_mask[idx, :] = 0
-            vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
-            return
-
-        if self.ll_interpreter.has_pending_stop():
+        if self.ll_matcher.is_stopped():
             self.finished = True

-        llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)
+        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
         if self.bitmask is None or self.bitmask.shape[0] < batch_size:
             # only create bitmask when batch gets larger
-            self.bitmask = llguidance.torch.allocate_token_bitmask(
+            self.bitmask = allocate_token_bitmask(
                 batch_size, self.llguidance_tokenizer.vocab_size
             )
         bitmask = self.bitmask
@@ -107,7 +96,7 @@ def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)
+        apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
         return GuidanceGrammar(
@@ -117,36 +106,64 @@ def copy(self):


 class GuidanceBackend(BaseGrammarBackend):
-    def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
+
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern: Optional[str] = None,
+        n_vocab: Optional[int] = None,
+    ):
         super().__init__()

         self.tokenizer = tokenizer
-        self.whitespace_flexible = (
-            True if whitespace_pattern == "whitespace_flexible" else False
-        )
-        self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
+        self.whitespace_pattern = whitespace_pattern
+        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)

-    def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
-        return GuidanceGrammar(
-            llguidance_tokenizer=self.llguidance_tokenizer,
-            serialized_grammar=serialized_grammar,
-        )
+    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
+        try:
+            return GuidanceGrammar(
+                llguidance_tokenizer=self.llguidance_tokenizer,
+                serialized_grammar=serialized_grammar,
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
+            return None

-    def dispatch_json(self, key_string: str) -> GuidanceGrammar:
-        json_schema = key_string
-        compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
-        serialized_grammar = compiler.compile(json_schema)
+    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = LLMatcher.grammar_from_json_schema(
+            key_string,
+            defaults={
+                "whitespace_pattern": self.whitespace_pattern,
+            },
+        )
         return self._from_serialized(serialized_grammar)

-    def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
-        compiler = llguidance.RegexCompiler()
-        serialized_grammar = compiler.compile(regex=key_string)
+    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = grammar_from("regex", key_string)
         return self._from_serialized(serialized_grammar)

-    def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
-        compiler = llguidance.LarkCompiler()
-        serialized_grammar = compiler.compile(any_to_lark(key_string))
-        return self._from_serialized(serialized_grammar)
+    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            serialized_grammar = grammar_from("ebnf", key_string)
+            return self._from_serialized(serialized_grammar)
+        except ValueError as e:
+            logger.warning(f"Skip invalid ebnf: {key_string}, {e=}")
+            return None

-    def dispatch_structural_tag(self, key_string: str):
-        return super().dispatch_structural_tag(key_string)
+    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            structural_tag = json.loads(key_string)
+            tags = [
+                StructTag(
+                    begin=structure["begin"],
+                    grammar=structure["schema"],
+                    end=structure["end"],
+                    trigger=structural_tag["triggers"][0],  # TODO: handle multiple triggers?
+                )
+                for structure in structural_tag["structures"]
+            ]
+            g = StructTag.to_grammar(tags)
+            return self._from_serialized(g)
+        except Exception as e:
+            logger.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
+            return None
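The diff above replaces the `LLInterpreter` flow, with its explicit fast-forward token bookkeeping, by the llguidance 0.7 `LLMatcher` API: the grammar is compiled up front, `fill_next_token_bitmask` writes the allowed-token mask for a batch row, `apply_token_bitmask_inplace` masks the logits, and `consume_token` advances the matcher. A minimal end-to-end sketch of that flow, using only calls that appear in the diff; the `gpt2` tokenizer, the random logits, and the greedy pick are illustrative assumptions, not part of the PR:

```python
import json

import torch
from llguidance import LLMatcher
from llguidance.hf import from_tokenizer
from llguidance.torch import (
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
    fill_next_token_bitmask,
)
from transformers import AutoTokenizer

# Wrap a HF tokenizer for llguidance (passing None lets it infer n_vocab).
ll_tok = from_tokenizer(AutoTokenizer.from_pretrained("gpt2"), None)

# Compile a JSON schema into a serialized grammar, as dispatch_json does.
schema = json.dumps({"type": "object", "properties": {"a": {"type": "integer"}}})
grammar = LLMatcher.grammar_from_json_schema(
    schema, defaults={"whitespace_pattern": None}
)
matcher = LLMatcher(ll_tok, grammar)

# One decoding step for a batch of one: mask, sample, advance.
bitmask = allocate_token_bitmask(1, ll_tok.vocab_size)
logits = torch.randn(1, ll_tok.vocab_size)  # stand-in for real model logits

fill_next_token_bitmask(matcher, bitmask, 0)  # fill row 0 of the mask
apply_token_bitmask_inplace(logits, bitmask)  # disallowed tokens -> -inf
token = int(logits.argmax(dim=-1).item())  # greedy pick, for illustration
if not matcher.consume_token(token):
    raise RuntimeError(matcher.get_error())
print(matcher.is_stopped(), matcher.compute_ff_tokens())
```

`dispatch_structural_tag` now builds the grammar from `StructTag` entries instead of deferring to the base class. Judging from the dict accesses in the new code, the expected payload looks roughly like the following; the field set is inferred from the parsing code, so treat it as an assumption rather than a spec:

```python
# Hypothetical structural-tag payload matching the keys the dispatcher reads:
# top-level "triggers" and "structures", each structure with begin/schema/end.
key_string = json.dumps(
    {
        "triggers": ["<function="],
        "structures": [
            {
                "begin": "<function=get_weather>",
                "schema": {"type": "object", "properties": {"city": {"type": "string"}}},
                "end": "</function>",
            }
        ],
    }
)
```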
test/srt/test_ebnf_constrained.py (6 changes: 6 additions & 0 deletions)
@@ -238,5 +238,11 @@ def test_ebnf_generate_custom_log_format(self):
         )


+class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, "llguidance", disable_overlap=False)
+
+
 if __name__ == "__main__":
     unittest.main()