diff --git a/components/src/dynamo/sglang/args.py b/components/src/dynamo/sglang/args.py
index 931227edd266..2447be37277a 100644
--- a/components/src/dynamo/sglang/args.py
+++ b/components/src/dynamo/sglang/args.py
@@ -353,7 +353,6 @@ async def parse_args(args: list[str]) -> Config:
     server_args.served_model_name = parsed_args.served_model_name
     server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False)
     server_args.log_level = getattr(parsed_args, "log_level", "info")
-    server_args.skip_tokenizer_init = True
     server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None)
     server_args.tp_size = getattr(parsed_args, "tp_size", 1)
     server_args.dp_size = getattr(parsed_args, "dp_size", 1)
@@ -389,15 +388,9 @@ async def parse_args(args: list[str]) -> Config:
             FutureWarning,
             stacklevel=2,
         )
-        logging.info(
-            "Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
-        )
-        server_args.skip_tokenizer_init = False
+        logging.info("Using SGLang's built in tokenizer")
     else:
-        logging.info(
-            "Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
-        )
-        server_args.skip_tokenizer_init = True
+        logging.info("Using dynamo's built in tokenizer")
 
     # Derive use_kv_events from server_args.kv_events_config
     # Check that kv_events_config exists AND publisher is not "null" ("zmq" or any future publishers)
diff --git a/components/src/dynamo/sglang/register.py b/components/src/dynamo/sglang/register.py
index 4674a9e4ff2f..5f12d490d180 100644
--- a/components/src/dynamo/sglang/register.py
+++ b/components/src/dynamo/sglang/register.py
@@ -38,9 +38,9 @@ async def _register_model_with_runtime_config(
     """
     runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
 
-    if not server_args.skip_tokenizer_init:
+    if dynamo_args.use_sglang_tokenizer:
         logging.warning(
-            "The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
+            "Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
        )
         input_type = ModelInput.Text
         # Only override output_type for chat models, not for embeddings
diff --git a/components/src/dynamo/sglang/request_handlers/handler_base.py b/components/src/dynamo/sglang/request_handlers/handler_base.py
index 399cb25ebf29..a6ef5fcd48db 100644
--- a/components/src/dynamo/sglang/request_handlers/handler_base.py
+++ b/components/src/dynamo/sglang/request_handlers/handler_base.py
@@ -3,6 +3,7 @@
 
 import asyncio
 import inspect
+import json
 import logging
 import random
 from abc import ABC, abstractmethod
@@ -173,13 +174,13 @@ def __init__(
         self.metrics_publisher = publisher.metrics_publisher
         self.kv_publisher = publisher.kv_publisher
         self.serving_mode = config.serving_mode
-        self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
+        self.use_sglang_tokenizer = config.dynamo_args.use_sglang_tokenizer
         self.enable_trace = config.server_args.enable_trace
 
         if engine is not None:
             self.input_param_manager = InputParamManager(
                 self.engine.tokenizer_manager.tokenizer
-                if not self.skip_tokenizer_init
+                if self.use_sglang_tokenizer
                 else None
             )
             self._engine_supports_priority = (
@@ -430,13 +431,24 @@ def cleanup(self) -> None:
 
     def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
         request_input = self.input_param_manager.get_input_param(
-            request, use_tokenizer=not self.skip_tokenizer_init
+            request, use_tokenizer=self.use_sglang_tokenizer
         )
 
         return {
             "prompt" if isinstance(request_input, str) else "input_ids": request_input
         }
 
+    @staticmethod
+    def _get_guided_decoding_params(
+        guided_decoding: Optional[Dict[str, Any]],
+    ) -> Dict[str, Any]:
+        """Extract guided decoding params (e.g. json_schema) for SGLang sampling_params."""
+        if isinstance(guided_decoding, dict):
+            json_schema = guided_decoding.get("json")
+            if json_schema is not None:
+                return {"json_schema": json.dumps(json_schema)}
+        return {}
+
     @staticmethod
     def _generate_bootstrap_room() -> int:
         """Generate a unique bootstrap room ID for disaggregated serving.
diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
index 3bbca2d4592b..f78a2e9f5e57 100644
--- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
+++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
@@ -88,7 +88,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             Dict of sampling parameters for SGLang engine.
         """
-        if self.skip_tokenizer_init:
+        if not self.use_sglang_tokenizer:
             # Token-based request format
             sampling_opts = request.get("sampling_options", {})
             stop_conditions = request.get("stop_conditions", {})
@@ -99,6 +99,9 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
                 "top_k": sampling_opts.get("top_k"),
                 "max_new_tokens": stop_conditions.get("max_tokens"),
                 "ignore_eos": stop_conditions.get("ignore_eos"),
+                **self._get_guided_decoding_params(
+                    sampling_opts.get("guided_decoding")
+                ),
             }
         else:
             # OpenAI request format
@@ -107,6 +110,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
                 "top_p": request.get("top_p"),
                 "top_k": request.get("top_k"),
                 "max_new_tokens": request.get("max_tokens"),
+                **self._get_guided_decoding_params(request.get("guided_decoding")),
             }
 
         return {k: v for k, v in param_mapping.items() if v is not None}
@@ -171,7 +175,7 @@ async def generate(
             **self._priority_kwargs(priority),
         )
 
-        if self.skip_tokenizer_init:
+        if not self.use_sglang_tokenizer:
             async for out in self._process_token_stream(decode, context):
                 yield out
         else:
@@ -202,7 +206,7 @@ async def generate(
                 data_parallel_rank=dp_rank,
                 **self._priority_kwargs(priority),
             )
-            if self.skip_tokenizer_init:
+            if not self.use_sglang_tokenizer:
                 async for out in self._process_token_stream(agg, context):
                     yield out
             else:
diff --git a/components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py b/components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
index 75f239461831..2213a0971156 100644
--- a/components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
+++ b/components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
@@ -89,7 +89,7 @@ async def generate(
         )
 
         # Process stream output (token-based or text-based)
-        if self.skip_tokenizer_init:
+        if not self.use_sglang_tokenizer:
             async for out in self._process_token_stream(async_gen, context):
                 yield out
         else:
diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py
index a0b16c1d4a85..4ae344382a61 100644
--- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py
+++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py
@@ -86,6 +86,9 @@ async def generate(
             "top_p": sampling_opts.get("top_p"),
             "top_k": sampling_opts.get("top_k"),
             "max_new_tokens": stop_conditions.get("max_tokens"),
+            **self._get_guided_decoding_params(
+                sampling_opts.get("guided_decoding")
+            ),
         }
         sampling_params = {
             k: v for k, v in sampling_params.items() if v is not None