@@ -150,7 +150,9 @@ class ReqState:
150150
151151 # For streaming output
152152 last_output_offset : int = 0
153+
153154 # For incremental state update.
155+ # TODO(lianmin): do not initialize some lists if not needed.
154156 text : str = ""
155157 output_ids : List [int ] = dataclasses .field (default_factory = list )
156158 input_token_logprobs_val : List [float ] = dataclasses .field (default_factory = list )
@@ -199,7 +201,6 @@ def __init__(
199201 self .model_path = server_args .model_path
200202 self .served_model_name = server_args .served_model_name
201203 self .model_config = ModelConfig .from_server_args (server_args )
202-
203204 self .is_generation = self .model_config .is_generation
204205 self .is_image_gen = self .model_config .is_image_gen
205206 self .context_len = self .model_config .context_len
@@ -251,19 +252,36 @@ def __init__(
251252 self .dump_requests_threshold = 1000
252253 self .dump_request_list : List [Tuple ] = []
253254 self .log_request_metadata = self .get_log_request_metadata ()
255+ self .asyncio_tasks = set ()
256+ self .session_futures = {} # session_id -> asyncio event
257+ self .max_req_input_len = None
254258
255259 # The event to notify the weight sync is finished.
256260 self .model_update_lock = RWLock ()
257261 self .model_update_result : Optional [Awaitable [UpdateWeightFromDiskReqOutput ]] = (
258262 None
259263 )
260- self .asyncio_tasks = set ()
261264
262- # For session info
263- self .session_futures = {} # session_id -> asyncio event
265+ # For pd disaggregtion
266+ self .disaggregation_mode = DisaggregationMode (
267+ self .server_args .disaggregation_mode
268+ )
269+ self .transfer_backend = TransferBackend (
270+ self .server_args .disaggregation_transfer_backend
271+ )
272+ # Start kv boostrap server on prefill
273+ if self .disaggregation_mode == DisaggregationMode .PREFILL :
274+ # only start bootstrap server on prefill tm
275+ kv_bootstrap_server_class = get_kv_class (
276+ self .transfer_backend , KVClassType .BOOTSTRAP_SERVER
277+ )
278+ self .bootstrap_server = kv_bootstrap_server_class (
279+ self .server_args .disaggregation_bootstrap_port
280+ )
264281
265- # Set after scheduler is initialized
266- self .max_req_input_len = None
282+ # For load balancing
283+ self .current_load = 0
284+ self .current_load_lock = asyncio .Lock ()
267285
268286 # Metrics
269287 if self .enable_metrics :
@@ -393,66 +411,29 @@ def __init__(
393411 ]
394412 )
395413
396- # For pd disaggregtion
397- self .disaggregation_mode = DisaggregationMode (
398- self .server_args .disaggregation_mode
399- )
400- self .transfer_backend = TransferBackend (
401- self .server_args .disaggregation_transfer_backend
402- )
403- # Start kv boostrap server on prefill
404- if self .disaggregation_mode == DisaggregationMode .PREFILL :
405- # only start bootstrap server on prefill tm
406- kv_bootstrap_server_class = get_kv_class (
407- self .transfer_backend , KVClassType .BOOTSTRAP_SERVER
408- )
409- self .bootstrap_server = kv_bootstrap_server_class (
410- self .server_args .disaggregation_bootstrap_port
411- )
412-
413- self .current_load = 0
414- self .current_load_lock = asyncio .Lock ()
415-
416414 async def generate_request (
417415 self ,
418416 obj : Union [GenerateReqInput , EmbeddingReqInput ],
419417 request : Optional [fastapi .Request ] = None ,
420418 ):
421419 created_time = time .time ()
422-
423420 self .auto_create_handle_loop ()
421+ obj .normalize_batch_and_arguments ()
424422
425423 if isinstance (obj , EmbeddingReqInput ) and self .is_generation :
426424 raise ValueError (
427425 "This model does not appear to be an embedding model by default. "
428426 "Please add `--is-embedding` when launching the server or try another model."
429427 )
430428
431- obj .normalize_batch_and_arguments ()
432-
433- if isinstance (obj , GenerateReqInput ):
434- return_hidden_states = obj .return_hidden_states
435- has_return_hidden_states = return_hidden_states == True or (
436- isinstance (return_hidden_states , list ) and any (return_hidden_states )
437- )
438- if (
439- not self .server_args .enable_return_hidden_states
440- and has_return_hidden_states
441- ):
442- raise ValueError (
443- "return_hidden_states=True requires the server to be started "
444- "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
445- )
446-
447429 if self .log_requests :
448430 max_length , skip_names , _ = self .log_request_metadata
449431 logger .info (
450432 f"Receive: obj={ dataclass_to_string_truncated (obj , max_length , skip_names = skip_names )} "
451433 )
452434
453435 async with self .model_update_lock .reader_lock :
454- is_single = obj .is_single
455- if is_single :
436+ if obj .is_single :
456437 tokenized_obj = await self ._tokenize_one_request (obj )
457438 state = self ._send_one_request (obj , tokenized_obj , created_time )
458439 async for response in self ._wait_one_response (obj , state , request ):
@@ -514,12 +495,12 @@ async def _tokenize_one_request(
514495 else :
515496 image_inputs : Optional [Dict ] = None
516497
517- self ._validate_token_len (obj , input_ids )
498+ self ._validate_one_request (obj , input_ids )
518499 return self ._create_tokenized_object (
519500 obj , input_text , input_ids , input_embeds , image_inputs , token_type_ids
520501 )
521502
522- def _validate_token_len (
503+ def _validate_one_request (
523504 self , obj : Union [GenerateReqInput , EmbeddingReqInput ], input_ids : List [int ]
524505 ) -> None :
525506 """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
@@ -548,6 +529,24 @@ def _validate_token_len(
548529 )
549530 raise ValueError (error_msg )
550531
532+ if isinstance (obj , GenerateReqInput ):
533+ if (
534+ obj .return_hidden_states
535+ and not self .server_args .enable_return_hidden_states
536+ ):
537+ raise ValueError (
538+ "The server is not configured to return the hidden states. "
539+ "Please set `--enable-return-hidden-states` to enable this feature."
540+ )
541+ if (
542+ obj .custom_logit_processor
543+ and not self .server_args .enable_custom_logit_processor
544+ ):
545+ raise ValueError (
546+ "The server is not configured to enable custom logit processor. "
547+ "Please set `--enable-custom-logits-processor` to enable this feature."
548+ )
549+
551550 def _create_tokenized_object (
552551 self ,
553552 obj : Union [GenerateReqInput , EmbeddingReqInput ],
@@ -558,24 +557,6 @@ def _create_tokenized_object(
558557 token_type_ids : Optional [List [int ]] = None ,
559558 ) -> Union [TokenizedGenerateReqInput , TokenizedEmbeddingReqInput ]:
560559 """Create a tokenized request object from common parameters."""
561-
562- if self .is_generation :
563- return_logprob = obj .return_logprob
564- logprob_start_len = obj .logprob_start_len
565- top_logprobs_num = obj .top_logprobs_num
566- token_ids_logprob = obj .token_ids_logprob
567- session_params = (
568- SessionParams (** obj .session_params ) if obj .session_params else None
569- )
570- if (
571- obj .custom_logit_processor
572- and not self .server_args .enable_custom_logit_processor
573- ):
574- raise ValueError (
575- "The server is not configured to enable custom logit processor. "
576- "Please set `--enable-custom-logits-processor` to enable this feature."
577- )
578-
579560 # Parse sampling parameters
580561 # Note: if there are preferred sampling params, we use them if they are not
581562 # explicitly passed in sampling_params
@@ -589,16 +570,20 @@ def _create_tokenized_object(
589570
590571 # Build return object
591572 if isinstance (obj , GenerateReqInput ):
573+ session_params = (
574+ SessionParams (** obj .session_params ) if obj .session_params else None
575+ )
576+
592577 tokenized_obj = TokenizedGenerateReqInput (
593578 obj .rid ,
594579 input_text ,
595580 input_ids ,
596581 image_inputs ,
597582 sampling_params ,
598- return_logprob ,
599- logprob_start_len ,
600- top_logprobs_num ,
601- token_ids_logprob ,
583+ obj . return_logprob ,
584+ obj . logprob_start_len ,
585+ obj . top_logprobs_num ,
586+ obj . token_ids_logprob ,
602587 obj .stream ,
603588 bootstrap_host = obj .bootstrap_host ,
604589 bootstrap_port = obj .bootstrap_port ,
0 commit comments