7 changes: 5 additions & 2 deletions python/sglang/srt/constrained/base_grammar_backend.py
@@ -60,7 +60,7 @@ def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         raise NotImplementedError()
 
     def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError()
+        return self
 
     @property
     def finished(self):
@@ -99,9 +99,12 @@ def jump_and_retokenize(
         raise NotImplementedError()
 
 
+INVALID_GRAMMAR_OBJ = BaseGrammarObject()
+
+
 @dataclass
 class CacheEntry:
-    value: Optional[BaseGrammarObject]
+    value: BaseGrammarObject
     event: Event
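
The net effect of this file: a failed grammar compilation is now represented by the shared INVALID_GRAMMAR_OBJ sentinel instead of None, so CacheEntry.value stays non-optional and copy() on the sentinel can simply return itself. Below is a minimal sketch of how a consumer of the cache might distinguish the sentinel; consume_entry is an illustrative helper, not code from this PR:

    from dataclasses import dataclass
    from threading import Event
    from typing import Optional

    class BaseGrammarObject:
        def copy(self) -> "BaseGrammarObject":
            # The default copy() now hands back the same object, which is all a
            # stateless sentinel needs.
            return self

    # Shared sentinel returned by every backend when grammar compilation fails.
    INVALID_GRAMMAR_OBJ = BaseGrammarObject()

    @dataclass
    class CacheEntry:
        value: BaseGrammarObject  # always populated, never None
        event: Event

    def consume_entry(entry: CacheEntry) -> Optional[BaseGrammarObject]:
        # Hypothetical caller: an identity check against the sentinel tells the
        # scheduler the grammar was invalid and the request should be aborted,
        # instead of silently treating a missing value as "no grammar".
        if entry.value is INVALID_GRAMMAR_OBJ:
            return None
        return entry.value.copy()
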
17 changes: 9 additions & 8 deletions python/sglang/srt/constrained/llguidance_backend.py
@@ -28,6 +28,7 @@
 )
 
 from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
     BaseGrammarBackend,
     BaseGrammarObject,
 )
@@ -126,8 +127,8 @@ def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
                 serialized_grammar=serialized_grammar,
             )
         except Exception as e:
-            logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
-            return None
+            logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
 
     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
         try:
@@ -138,8 +139,8 @@ def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
                 },
             )
         except Exception as e:
-            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
-            return None
+            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_serialized(serialized_grammar)
 
     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
@@ -151,8 +152,8 @@ def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
             serialized_grammar = grammar_from("ebnf", key_string)
             return self._from_serialized(serialized_grammar)
         except ValueError as e:
-            logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
-            return None
+            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
 
     def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
         try:
@@ -169,5 +170,5 @@ def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
             g = StructTag.to_grammar(tags)
             return self._from_serialized(g)
         except Exception as e:
-            logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
-            return None
+            logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
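
Every dispatch path in this backend now shares the same failure shape: log at error level and hand back the sentinel instead of None, so callers always receive a grammar object. A condensed sketch of that pattern; dispatch_with_fallback and compile_fn are placeholders standing in for whichever compiler call each backend wraps, not actual SGLang APIs:

    import logging

    from sglang.srt.constrained.base_grammar_backend import INVALID_GRAMMAR_OBJ

    logger = logging.getLogger(__name__)

    def dispatch_with_fallback(compile_fn, key_string: str):
        # Shared error-handling shape after this change: the caller never sees
        # None; an identity check against INVALID_GRAMMAR_OBJ later marks the
        # request as one that must be aborted.
        try:
            return compile_fn(key_string)
        except Exception as e:
            logger.error(f"Hit invalid grammar: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
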
9 changes: 5 additions & 4 deletions python/sglang/srt/constrained/outlines_backend.py
@@ -24,6 +24,7 @@
 from pydantic import BaseModel
 
 from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
     BaseGrammarBackend,
     BaseGrammarObject,
 )
@@ -151,8 +152,8 @@ def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
                 # outlines <= 0.0.46
                 guide = RegexGuide(regex, self.outlines_tokenizer)
         except interegular.patterns.InvalidSyntax as e:
-            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
-            return None
+            logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
 
         jump_forward_map = None
         return OutlinesGrammar(guide, jump_forward_map)
@@ -170,8 +171,8 @@ def dispatch_json(self, key_string: str):
                 whitespace_pattern=self.whitespace_pattern,
             )
         except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
-            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
-            return None
+            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._compile_regex(regex)
 
     def dispatch_regex(self, key_string: str):
36 changes: 18 additions & 18 deletions python/sglang/srt/constrained/xgrammar_backend.py
@@ -28,6 +28,7 @@
 )
 
 from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
     BaseGrammarBackend,
     BaseGrammarObject,
 )
@@ -152,10 +153,11 @@ def __init__(
     ):
         super().__init__()
 
-        tokenizer_info = TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=vocab_size
-        )
-        override_stop_tokens = None
+        if True:
+            tokenizer_info = TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=vocab_size
+            )
+            override_stop_tokens = None

Comment on lines +156 to +160 (Contributor, severity: medium):

This if True: block appears to be redundant. Could it be removed to simplify the code?

        tokenizer_info = TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=vocab_size
        )
        override_stop_tokens = None


         self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
@@ -178,25 +180,26 @@ def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
                 ctx = self.grammar_compiler.compile_builtin_json_grammar()
             else:
                 ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
-        except RuntimeError as e:
-            logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
-            return None
+
+        except (RuntimeError, json.decoder.JSONDecodeError) as e:
+            logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)
 
     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             ctx = self.grammar_compiler.compile_grammar(key_string)
         except RuntimeError as e:
-            logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
-            return None
+            logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)
 
     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             ctx = self.grammar_compiler.compile_regex(key_string)
         except RuntimeError as e:
-            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
-            return None
+            logging.error(f"Hit invalid regex: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)
 
     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
@@ -213,13 +216,10 @@ def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
                 ctx = self.grammar_compiler.compile_structural_tag(
                     tags, structural_tag["triggers"]
                 )
-        except RuntimeError as e:
-            logging.warning(
-                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
-            )
-            return None
+        except (RuntimeError, json.decoder.JSONDecodeError) as e:
+            logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)
 
     def reset(self):
-        if self.grammar_compiler:
-            self.grammar_compiler.clear_cache()
+        self.grammar_compiler.clear_cache()

Comment on lines 224 to +225 (Contributor, severity: medium):

The if self.grammar_compiler: check was removed from the reset method. This implies self.grammar_compiler is guaranteed to be initialized. While this seems to be the case from the constructor, could you confirm if there are any scenarios where self.grammar_compiler might not be set before reset is called? If it's always set, this change is fine.

4 changes: 2 additions & 2 deletions python/sglang/srt/entrypoints/http_server.py
@@ -256,7 +256,7 @@ async def stream_results() -> AsyncIterator[bytes]:
                 ) + b"\n\n"
         except ValueError as e:
             out = {"error": {"message": str(e)}}
-            logger.error(f"Error: {e}")
+            logger.error(f"[http_server] Error: {e}")
             yield b"data: " + orjson.dumps(
                 out, option=orjson.OPT_NON_STR_KEYS
             ) + b"\n\n"
@@ -274,7 +274,7 @@ async def stream_results() -> AsyncIterator[bytes]:
        ).__anext__()
        return ret
    except ValueError as e:
-        logger.error(f"Error: {e}")
+        logger.error(f"[http_server] Error: {e}")
        return _create_error_response(e)
14 changes: 13 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -37,6 +37,7 @@
 import logging
 import threading
 from enum import Enum, auto
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -51,6 +52,7 @@
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -60,7 +62,7 @@
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
+from sglang.srt.utils import flatten_nested_list, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -771,6 +773,16 @@ def log_time_stats(self):
         logger.info(f"{prefix}: {self.time_stats}")
         self.has_log_time_stats = True
 
+    def set_finish_with_abort(self, error_msg: str):
+        if get_tensor_model_parallel_rank() == 0:
+            logger.error(f"{error_msg}, {self.rid=}")
+        self.multimodal_inputs = None
+        self.grammar = None
+        self.origin_input_ids = [0]  # set it to one token to skip the long prefill

Review comment (Contributor, severity: medium):

The line self.origin_input_ids = [0] with the comment "set it to one token to skip the long prefill" is a bit of an implicit contract. While the comment explains the intent, it might be worth considering if there's a more explicit way to signal to the prefill logic that the request is aborted and should be skipped or handled minimally, rather than relying on a specific input ID pattern. However, if this is a well-established pattern in the codebase, it's acceptable.

+        self.finished_reason = FINISH_ABORT(
+            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+        )
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
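
The new set_finish_with_abort helper gives the scheduler one place to turn an invalid request into a finished one: it logs on TP rank 0, drops multimodal inputs and the grammar, shrinks the prompt to a single token so prefill does almost no work, and records a FINISH_ABORT reason with HTTP 400. A hedged sketch of how it could be wired up with the INVALID_GRAMMAR_OBJ sentinel from the grammar backends; abort_if_invalid_grammar and the exact req fields are assumptions, only set_finish_with_abort itself comes from this diff:

    from sglang.srt.constrained.base_grammar_backend import INVALID_GRAMMAR_OBJ

    def abort_if_invalid_grammar(req, grammar_obj) -> bool:
        # Sketch of scheduler-side wiring: if the compiled grammar is the
        # invalid sentinel, finish the request as a BadRequestError instead of
        # running constrained decoding with a broken grammar.
        if grammar_obj is INVALID_GRAMMAR_OBJ:
            req.set_finish_with_abort("Invalid grammar request")
            return True
        req.grammar = grammar_obj
        return False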