From 08cfab88bd21b568ff513d4e8170a103b0b3fb39 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 9 Dec 2025 23:11:27 +0000 Subject: [PATCH] [Auto Sync] Update data_parallel_controller.py, detokenizer... (20251209) Co-authored-by: Lianmin Zheng --- .../srt/managers/data_parallel_controller.py | 19 +++++-- .../srt/managers/detokenizer_manager.py | 57 ++++++++++++------- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 4ed72a8c0b94..ec9ae1d3d4e1 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -21,7 +21,7 @@ import time from collections import deque from enum import Enum, auto -from typing import List, Optional +from typing import Callable, List, Optional import psutil import setproctitle @@ -119,14 +119,19 @@ def dispatch(self): class DataParallelController: """A controller that dispatches requests to multiple data parallel workers.""" - def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + run_scheduler_process_func: Callable, + ) -> None: # Parse args self.server_args = server_args self.port_args = port_args self.load_balance_method = LoadBalanceMethod.from_str( server_args.load_balance_method ) - self.run_scheduler_process = run_scheduler_process + self.run_scheduler_process_func = run_scheduler_process_func # For DP balance self.global_balance_id = 0 @@ -429,7 +434,7 @@ def launch_tensor_parallel_group( moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id: proc = mp.Process( - target=self.run_scheduler_process, + target=self.run_scheduler_process_func, args=( server_args, rank_port_args, @@ -511,7 +516,7 @@ def run_data_parallel_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer, - data_parallel_controller_class=DataParallelController, + run_scheduler_process_func: Callable = run_scheduler_process, ): setproctitle.setproctitle("sglang::data_parallel_controller") faulthandler.enable() @@ -529,7 +534,9 @@ def run_data_parallel_controller_process( trace_set_thread_info(thread_label) try: - controller = data_parallel_controller_class(server_args, port_args) + controller = DataParallelController( + server_args, port_args, run_scheduler_process_func + ) pipe_writer.send( { "status": "ready", diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 87922077e05e..715c9d1b7eaa 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -84,6 +84,7 @@ def __init__( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) + # Init tokenizer if server_args.skip_tokenizer_init: self.tokenizer = None else: @@ -95,8 +96,11 @@ def __init__( ) self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) - self.is_dummy = server_args.load_format == "dummy" + self.is_dummy = False + self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss" + self.disable_tokenizer_batch_decode = server_args.disable_tokenizer_batch_decode + # Init dispatcher self._request_dispatcher = TypeBasedDispatcher( [ (BatchEmbeddingOutput, self.handle_batch_embedding_out), @@ -106,9 +110,6 @@ def __init__( ] ) - self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss" - self.disable_tokenizer_batch_decode = server_args.disable_tokenizer_batch_decode - def event_loop(self): """The event loop that handles requests""" while True: @@ -148,7 +149,7 @@ def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput): # If it is embedding model, no detokenization is needed. return recv_obj - def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): + def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): bs = len(recv_obj.rids) # Initialize decode status @@ -176,8 +177,31 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): ) surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) - # TODO(lmzheng): better handle skip_special_tokens/spaces_between_special_tokens per request - if self.disable_tokenizer_batch_decode: + # Decode token ids to strings + # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request + if not self.disable_tokenizer_batch_decode: + if not self.is_dummy: + # Run normal batch decode + surr_texts = self.tokenizer.batch_decode( + surr_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[ + 0 + ], + ) + read_texts = self.tokenizer.batch_decode( + read_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[ + 0 + ], + ) + else: + # If it is dummy weights, just return dummy strings to prevent potential detokenization edge cases + surr_texts = ["dog" for _ in surr_ids] + read_texts = ["cat" for _ in read_ids] + else: + # Do not use batch decode to prevent some detokenization edge cases (e.g., gpt-oss). surr_texts = [ self.tokenizer.decode( surr, skip_special_tokens=skip, spaces_between_special_tokens=space @@ -198,17 +222,6 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): recv_obj.spaces_between_special_tokens, ) ] - else: - surr_texts = self.tokenizer.batch_decode( - surr_ids, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], - ) - read_texts = self.tokenizer.batch_decode( - read_ids, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], - ) # Incremental decoding output_strs = [] @@ -247,6 +260,11 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): s.sent_offset = len(output_str) output_strs.append(incremental_output) + return output_strs + + def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): + output_strs = self._decode_batch_token_id_output(recv_obj) + return BatchStrOutput( rids=recv_obj.rids, http_worker_ipcs=recv_obj.http_worker_ipcs, @@ -306,6 +324,7 @@ def __setitem__(self, key, value): def run_detokenizer_process( server_args: ServerArgs, port_args: PortArgs, + detokenizer_manager_class=DetokenizerManager, ): kill_itself_when_parent_died() setproctitle.setproctitle("sglang::detokenizer") @@ -313,7 +332,7 @@ def run_detokenizer_process( parent_process = psutil.Process().parent() try: - manager = DetokenizerManager(server_args, port_args) + manager = detokenizer_manager_class(server_args, port_args) if server_args.tokenizer_worker_num > 1: manager.multi_http_worker_event_loop() else: