diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a3c5001e87a1..46345efdad54 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -479,6 +479,7 @@ def init_request_dispatcher(self): self.sampling_params_class = SamplingParams self.signal_handler_class = SignalHandler + self.req_state_class = ReqState async def generate_request( self, @@ -1056,7 +1057,9 @@ def _send_one_request( trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid) tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid) self.send_to_scheduler.send_pyobj(tokenized_obj) - state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) + state = self.req_state_class( + [], False, asyncio.Event(), obj, created_time=created_time + ) state.request_sent_to_scheduler_ts = time.time() self.rid_to_state[obj.rid] = state trace_slice_end( @@ -1082,7 +1085,7 @@ def _send_batch_request( # Create states for each individual request in the batch for i, tokenized_obj in enumerate(tokenized_objs): tmp_obj = obj[i] - state = ReqState( + state = self.req_state_class( [], False, asyncio.Event(), tmp_obj, created_time=created_time ) self.rid_to_state[tmp_obj.rid] = state