Skip to content
11 changes: 2 additions & 9 deletions components/src/dynamo/sglang/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,6 @@ async def parse_args(args: list[str]) -> Config:
server_args.served_model_name = parsed_args.served_model_name
server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False)
server_args.log_level = getattr(parsed_args, "log_level", "info")
server_args.skip_tokenizer_init = True
server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None)
server_args.tp_size = getattr(parsed_args, "tp_size", 1)
server_args.dp_size = getattr(parsed_args, "dp_size", 1)
Expand Down Expand Up @@ -389,15 +388,9 @@ async def parse_args(args: list[str]) -> Config:
FutureWarning,
stacklevel=2,
)
logging.info(
"Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
)
server_args.skip_tokenizer_init = False
logging.info("Using SGLang's built in tokenizer")
else:
logging.info(
"Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
)
server_args.skip_tokenizer_init = True
logging.info("Using dynamo's built in tokenizer")

# Derive use_kv_events from server_args.kv_events_config
# Check that kv_events_config exists AND publisher is not "null" ("zmq" or any future publishers)
Expand Down
4 changes: 2 additions & 2 deletions components/src/dynamo/sglang/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ async def _register_model_with_runtime_config(
"""
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)

if not server_args.skip_tokenizer_init:
if dynamo_args.use_sglang_tokenizer:
logging.warning(
"The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
"Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
)
input_type = ModelInput.Text
# Only override output_type for chat models, not for embeddings
Expand Down
18 changes: 15 additions & 3 deletions components/src/dynamo/sglang/request_handlers/handler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import inspect
import json
import logging
import random
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -173,13 +174,13 @@ def __init__(
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.use_sglang_tokenizer = config.dynamo_args.use_sglang_tokenizer
self.enable_trace = config.server_args.enable_trace

if engine is not None:
self.input_param_manager = InputParamManager(
self.engine.tokenizer_manager.tokenizer
if not self.skip_tokenizer_init
if self.use_sglang_tokenizer
else None
)
self._engine_supports_priority = (
Expand Down Expand Up @@ -430,13 +431,24 @@ def cleanup(self) -> None:

def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
    """Build the engine-input keyword argument for a request.

    Delegates to the InputParamManager; when the SGLang tokenizer is in
    use the manager yields a text prompt, otherwise pre-tokenized ids.
    The result is keyed accordingly: "prompt" for str, "input_ids" otherwise.
    """
    payload = self.input_param_manager.get_input_param(
        request, use_tokenizer=self.use_sglang_tokenizer
    )
    key = "prompt" if isinstance(payload, str) else "input_ids"
    return {key: payload}

@staticmethod
def _get_guided_decoding_params(
guided_decoding: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""Extract guided decoding params (e.g. json_schema) for SGLang sampling_params."""
if isinstance(guided_decoding, dict):
json_schema = guided_decoding.get("json")
if json_schema is not None:
return {"json_schema": json.dumps(json_schema)}
return {}

@staticmethod
def _generate_bootstrap_room() -> int:
"""Generate a unique bootstrap room ID for disaggregated serving.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Dict of sampling parameters for SGLang engine.
"""
if self.skip_tokenizer_init:
if not self.use_sglang_tokenizer:
# Token-based request format
sampling_opts = request.get("sampling_options", {})
stop_conditions = request.get("stop_conditions", {})
Expand All @@ -99,6 +99,9 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"),
"ignore_eos": stop_conditions.get("ignore_eos"),
**self._get_guided_decoding_params(
sampling_opts.get("guided_decoding")
),
}
else:
# OpenAI request format
Expand All @@ -107,6 +110,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"top_p": request.get("top_p"),
"top_k": request.get("top_k"),
"max_new_tokens": request.get("max_tokens"),
**self._get_guided_decoding_params(request.get("guided_decoding")),
}

return {k: v for k, v in param_mapping.items() if v is not None}
Expand Down Expand Up @@ -171,7 +175,7 @@ async def generate(
**self._priority_kwargs(priority),
)

if self.skip_tokenizer_init:
if not self.use_sglang_tokenizer:
async for out in self._process_token_stream(decode, context):
yield out
else:
Expand Down Expand Up @@ -202,7 +206,7 @@ async def generate(
data_parallel_rank=dp_rank,
**self._priority_kwargs(priority),
)
if self.skip_tokenizer_init:
if not self.use_sglang_tokenizer:
async for out in self._process_token_stream(agg, context):
yield out
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def generate(
)

# Process stream output (token-based or text-based)
if self.skip_tokenizer_init:
if not self.use_sglang_tokenizer:
async for out in self._process_token_stream(async_gen, context):
yield out
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ async def generate(
"top_p": sampling_opts.get("top_p"),
"top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"),
**self._get_guided_decoding_params(
sampling_opts.get("guided_decoding")
),
}
sampling_params = {
k: v for k, v in sampling_params.items() if v is not None
Expand Down
Loading