diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 9e3e82a78f92..05425d26800f 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -15,6 +15,7 @@ def run_server(server_args):
         asyncio.run(serve_grpc(server_args))
     else:
+        # Default mode: HTTP server.
         from sglang.srt.entrypoints.http_server import launch_server

         launch_server(server_args)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 787de125728c..555cfb101ec5 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -17,7 +17,6 @@
 import copy
 import dataclasses
 import logging
-import math
 import os
 import pickle
 import signal
@@ -33,7 +32,6 @@
 import fastapi
 import orjson
-import torch
 import uvloop
 import zmq
 import zmq.asyncio
@@ -78,6 +76,9 @@
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
+from sglang.srt.managers.tokenizer_manager_multiitem_mixin import (
+    TokenizerManagerMultiItemMixin,
+)
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import (
@@ -117,16 +118,6 @@
 logger = logging.getLogger(__name__)


-def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
-    is_cross_node = server_args.dist_init_addr
-
-    if is_cross_node:
-        # Fallback to default CPU transport for multi-node
-        return "default"
-    else:
-        return "cuda_ipc"
-
-
 @dataclasses.dataclass
 class ReqState:
     """Store the state a request."""
@@ -171,7 +162,15 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-class TokenizerManager(TokenizerCommunicatorMixin):
+class InputFormat(Enum):
+    """Input format types for tokenization handling."""
+
+    SINGLE_STRING = 1  # Regular single text like "Hello world"
+    BATCH_STRINGS = 2  # Regular batch like ["Hello", "World"]
+    CROSS_ENCODER_PAIRS = 3  # Cross-encoder pairs like [["query", "document"]]
+
+
+class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixin):
     """TokenizerManager is a process that tokenizes the text."""

     def __init__(
@@ -219,29 +218,7 @@ def __init__(
             import_processors(
                 envs.SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE.value, overwrite=True
             )
-            try:
-                _processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                    use_fast=not server_args.disable_fast_image_processor,
-                )
-            except ValueError as e:
-                error_message = str(e)
-                if "does not have a slow version" in error_message:
-                    logger.info(
-                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
-                    )
-                    _processor = get_processor(
-                        server_args.tokenizer_path,
-                        tokenizer_mode=server_args.tokenizer_mode,
-                        trust_remote_code=server_args.trust_remote_code,
-                        revision=server_args.revision,
-                        use_fast=True,
-                    )
-                else:
-                    raise e
+            _processor = _get_processor_wrapper(server_args)

             transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
@@ -436,89 +413,18 @@ async def generate_request(
         obj.normalize_batch_and_arguments()

         if self.enable_trace:
-            external_trace_header = None
-            if request:
-                if "trace_context" in request.headers:
-                    trace_set_remote_propagate_context(request.headers["trace_context"])
-                else:
-                    external_trace_header = extract_trace_headers(request.headers)
-
-            self._trace_request_start(obj, created_time, external_trace_header)
-
+            self._trace_request_start(obj, created_time, request)
         if self.server_args.tokenizer_worker_num > 1:
             self._attach_multi_http_worker_info(obj)

         if self.log_requests:
-            max_length, skip_names, _ = self.log_request_metadata
-            logger.info(
-                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
-            )
-
-            # FIXME: This is a temporary fix to get the text from the input ids.
-            # We should remove this once we have a proper way.
-            if (
-                self.log_requests_level >= 2
-                and obj.text is None
-                and obj.input_ids is not None
-                and self.tokenizer is not None
-            ):
-                decoded = self.tokenizer.decode(
-                    obj.input_ids, skip_special_tokens=False
-                )
-                obj.text = decoded
+            self._log_received_request(obj)

         async with self.is_pause_cond:
             await self.is_pause_cond.wait_for(lambda: not self.is_pause)

         async with self.model_update_lock.reader_lock:
             if self.server_args.enable_lora and obj.lora_path:
-                if isinstance(obj.lora_path, str):
-                    unique_lora_paths = set([obj.lora_path])
-                else:
-                    unique_lora_paths = set(obj.lora_path)
-
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and len(unique_lora_paths) > self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Received request with {len(unique_lora_paths)} unique loras requested "
-                        f"but max loaded loras is {self.server_args.max_loaded_loras}"
-                    )
-
-                # Reload all existing LoRA adapters that have been dynamically unloaded
-                unregistered_loras = await self.lora_registry.get_unregistered_loras(
-                    unique_lora_paths
-                )
-                for lora_path in unregistered_loras:
-                    if lora_path is None:
-                        continue
-
-                    if lora_path not in self.lora_ref_cache:
-                        raise ValueError(
-                            f"Got LoRA adapter that has never been loaded: {lora_path}\n"
-                            f"All loaded adapters: {self.lora_ref_cache.keys()}."
-                        )
-
-                    logger.info(f"Reloading evicted adapter: {lora_path}")
-                    new_lora_ref = self.lora_ref_cache[lora_path]
-                    load_result = await self.load_lora_adapter(
-                        LoadLoRAAdapterReqInput(
-                            lora_name=new_lora_ref.lora_name,
-                            lora_path=new_lora_ref.lora_path,
-                            pinned=new_lora_ref.pinned,
-                        )
-                    )
-                    if (
-                        not load_result.success
-                        and "already loaded" not in load_result.error_message
-                    ):
-                        raise ValueError(
-                            f"Failed to implicitly load LoRA adapter {lora_path}: {load_result.error_message}"
-                        )
-
-                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
-                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+                await self._resolve_lora_path(obj)

             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
@@ -533,16 +439,16 @@ async def generate_request(
     def _detect_input_format(
         self, texts: Union[str, List[str]], is_cross_encoder: bool
-    ) -> str:
+    ) -> InputFormat:
         """Detect the format of input texts for proper tokenization handling.

         Returns:
-        - "single_string": Regular single text like "Hello world"
-        - "batch_strings": Regular batch like ["Hello", "World"]
-        - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        - InputFormat.SINGLE_STRING: Regular single text like "Hello world"
+        - InputFormat.BATCH_STRINGS: Regular batch like ["Hello", "World"]
+        - InputFormat.CROSS_ENCODER_PAIRS: Cross-encoder pairs like [["query", "document"]]
         """
         if isinstance(texts, str):
-            return "single_string"
+            return InputFormat.SINGLE_STRING

         if (
             is_cross_encoder
@@ -550,26 +456,26 @@ def _detect_input_format(
             and isinstance(texts[0], list)
             and len(texts[0]) == 2
         ):
-            return "cross_encoder_pairs"
+            return InputFormat.CROSS_ENCODER_PAIRS

-        return "batch_strings"
+        return InputFormat.BATCH_STRINGS

     def _prepare_tokenizer_input(
-        self, texts: Union[str, List[str]], input_format: str
+        self, texts: Union[str, List[str]], input_format: InputFormat
     ) -> Union[List[str], List[List[str]]]:
         """Prepare input for the tokenizer based on detected format."""
-        if input_format == "single_string":
+        if input_format == InputFormat.SINGLE_STRING:
             return [texts]  # Wrap single string for batch processing
-        elif input_format == "cross_encoder_pairs":
+        elif input_format == InputFormat.CROSS_ENCODER_PAIRS:
             return texts  # Already in correct format: [["query", "doc"]]
-        else:  # batch_strings
+        else:  # BATCH_STRINGS
             return texts  # Already in correct format: ["text1", "text2"]

     def _extract_tokenizer_results(
         self,
         input_ids: List[List[int]],
         token_type_ids: Optional[List[List[int]]],
-        input_format: str,
+        input_format: InputFormat,
         original_batch_size: int,
     ) -> Union[
         Tuple[List[int], Optional[List[int]]],
@@ -579,7 +485,7 @@ def _extract_tokenizer_results(

         # For single inputs (string or single cross-encoder pair), extract first element
         if (
-            input_format in ["single_string", "cross_encoder_pairs"]
+            input_format in [InputFormat.SINGLE_STRING, InputFormat.CROSS_ENCODER_PAIRS]
             and original_batch_size == 1
         ):
             single_input_ids = input_ids[0] if input_ids else []
@@ -643,7 +549,7 @@ async def _tokenize_texts(
         # Step 3: Choose tokenization strategy
         use_async_tokenizer = (
             self.async_dynamic_batch_tokenizer is not None
-            and input_format == "single_string"
+            and input_format == InputFormat.SINGLE_STRING
         )

         if use_async_tokenizer:
@@ -2088,50 +1994,6 @@ def _handle_update_weights_from_disk_req_output(self, recv_obj):
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)

-    def _initialize_multi_item_delimiter_text(self):
-        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
-        if (
-            hasattr(self.server_args, "multi_item_scoring_delimiter")
-            and self.server_args.multi_item_scoring_delimiter is not None
-            and self.tokenizer is not None
-        ):
-            try:
-                self.multi_item_delimiter_text = self.tokenizer.decode(
-                    [self.server_args.multi_item_scoring_delimiter],
-                    skip_special_tokens=False,
-                )
-            except Exception as e:
-                logger.warning(
-                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
-                )
-                self.multi_item_delimiter_text = None
-
-    def _build_multi_item_token_sequence(
-        self, query: List[int], items: List[List[int]], delimiter_token_id: int
-    ) -> List[int]:
-        """
-        Build a single token sequence for multi-item scoring.
-        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
-
-        Args:
-            query: Query token IDs
-            items: List of item token ID sequences
-            delimiter_token_id: Token ID to use as delimiter
-
-        Returns:
-            Combined token sequence
-        """
-        combined_sequence = query[:]  # Start with query
-
-        for item in items:
-            combined_sequence.append(delimiter_token_id)  # Add delimiter
-            combined_sequence.extend(item)  # Add item tokens
-
-        # Add final delimiter after the last item for logprob extraction
-        combined_sequence.append(delimiter_token_id)
-
-        return combined_sequence
-
     def _extract_logprobs_for_tokens(
         self, logprobs_data: List, label_token_ids: List[int]
     ) -> Dict[int, float]:
@@ -2152,282 +2014,99 @@
                 logprobs[token_id] = logprob
         return logprobs

-    def _convert_logprobs_to_scores(
-        self,
-        logprobs: Dict[int, float],
-        label_token_ids: List[int],
-        apply_softmax: bool,
-    ) -> List[float]:
-        """
-        Convert logprobs dictionary to ordered score list.
-
-        Args:
-            logprobs: Dictionary mapping token_id to logprob
-            label_token_ids: Token IDs in desired order
-            apply_softmax: Whether to apply softmax normalization
+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return

-        Returns:
-            List of scores in the same order as label_token_ids
-        """
-        score_list = [
-            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
-        ]
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_update_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_update_req)

-        if apply_softmax:
-            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+    async def _resolve_lora_path(self, obj: Union[GenerateReqInput, EmbeddingReqInput]):
+        if isinstance(obj.lora_path, str):
+            unique_lora_paths = set([obj.lora_path])
         else:
-            # Convert logprobs to probabilities if not using softmax
-            score_list = [
-                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-            ]
+            unique_lora_paths = set(obj.lora_path)

-        return score_list
-
-    def _process_multi_item_scoring_results(
-        self,
-        results: Any,
-        items: List,
-        label_token_ids: List[int],
-        apply_softmax: bool,
-        batch_request=None,
-    ) -> List[List[float]]:
-        """
-        Process results from multi-item scoring request.
-        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
-
-        Args:
-            results: Results from generate_request
-            items: List of items being scored
-            label_token_ids: Token IDs to extract scores for
-            apply_softmax: Whether to apply softmax normalization
-            batch_request: The original batch request containing input sequence
-
-        Returns:
-            List of score lists, one for each item
-        """
-        single_result = results[0] if isinstance(results, list) else results
-
-        # For multi-item scoring, logprobs are in input_token_ids_logprobs
-        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
-
-        if not input_logprobs:
-            raise RuntimeError(
-                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '')}. "
-                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
-            )
-
-        scores = []
-        num_items = len(items) if isinstance(items, list) else 1
-
-        # Check if we have the expected number of logprobs
-        expected_logprobs_count = num_items + 1
-        if len(input_logprobs) != expected_logprobs_count:
-            raise RuntimeError(
-                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
-                f"with {num_items} items, but got {len(input_logprobs)}. "
-                f"Request ID: {single_result['meta_info'].get('id', '')}"
-            )
-
-        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
-        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
-        start_idx = 1 if len(input_logprobs) > 1 else 0
-
-        # Process logprobs for each item position (excluding first delimiter)
-        for item_idx in range(num_items):
-            logprob_idx = start_idx + item_idx
-            item_logprobs_data = input_logprobs[logprob_idx]
-            logprobs = self._extract_logprobs_for_tokens(
-                item_logprobs_data, label_token_ids
-            )
-            score_list = self._convert_logprobs_to_scores(
-                logprobs, label_token_ids, apply_softmax
-            )
-            scores.append(score_list)
-
-        return scores
-
-    def _process_single_item_scoring_results(
-        self, results: Any, label_token_ids: List[int], apply_softmax: bool
-    ) -> List[List[float]]:
-        """
-        Process results from single-item scoring request.
-        Single-item scoring results are stored in output_token_ids_logprobs.
-
-        Args:
-            results: Results from generate_request
-            label_token_ids: Token IDs to extract scores for
-            apply_softmax: Whether to apply softmax normalization
-
-        Returns:
-            List of score lists, one for each result
-        """
-        scores = []
-
-        for result in results:
-            # For single-item scoring, logprobs are in output_token_ids_logprobs
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            if not output_logprobs or len(output_logprobs) == 0:
-                raise RuntimeError(
-                    f"output_logprobs is empty for request {result['meta_info'].get('id', '')}."
-                )
-
-            # Extract logprobs for the first (and only) position
-            logprobs = self._extract_logprobs_for_tokens(
-                output_logprobs[0], label_token_ids
-            )
-            score_list = self._convert_logprobs_to_scores(
-                logprobs, label_token_ids, apply_softmax
+        if (
+            self.server_args.max_loaded_loras is not None
+            and len(unique_lora_paths) > self.server_args.max_loaded_loras
+        ):
+            raise ValueError(
+                f"Received request with {len(unique_lora_paths)} unique loras requested "
+                f"but max loaded loras is {self.server_args.max_loaded_loras}"
             )
-            scores.append(score_list)
-
-        return scores
-
-    async def score_request(
-        self,
-        query: Optional[Union[str, List[int]]] = None,
-        items: Optional[Union[str, List[str], List[List[int]]]] = None,
-        label_token_ids: Optional[List[int]] = None,
-        apply_softmax: bool = False,
-        item_first: bool = False,
-        request: Optional[Any] = None,
-    ) -> List[List[float]]:
-        """
-        Score the probability of specified token IDs appearing after the given (query + item) pair.
-        This method supports two scoring approaches:
-        1. Single-Item scoring (default): Process each query+item pair independently
-        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
-           multiple items into a single sequence using delimiter for efficient processing.
-           Note: item_first parameter is ignored in multi-item scoring mode since it uses
-           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
-
-        Multi-item scoring works with both text and pre-tokenized inputs:
-        - Text: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
-        - Tokens: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
-
-        Args:
-            query: The query text or pre-tokenized query token IDs
-            items: The item text(s) or pre-tokenized item token IDs
-            label_token_ids: List of token IDs to compute probabilities for
-            apply_softmax: Whether to normalize probabilities using softmax
-            item_first: If True, prepend items to query. Ignored for multi-item scoring.
-            request: Optional FastAPI request object
-
-        Returns:
-            List of lists containing probabilities for each item and each label token
-        """
-        if label_token_ids is None:
-            raise ValueError("label_token_ids must be provided")
-
-        if self.tokenizer is not None:
-            vocab_size = self.tokenizer.vocab_size
-            for token_id in label_token_ids:
-                if token_id >= vocab_size:
-                    raise ValueError(
-                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
-                    )
-
-        # Check if multi-item scoring is enabled by presence of delimiter
-        use_multi_item_scoring = (
-            self.server_args.multi_item_scoring_delimiter is not None
-            and self.multi_item_delimiter_text is not None
+        # Reload all existing LoRA adapters that have been dynamically unloaded
+        unregistered_loras = await self.lora_registry.get_unregistered_loras(
+            unique_lora_paths
         )
+        for lora_path in unregistered_loras:
+            if lora_path is None:
+                continue

-        batch_request = GenerateReqInput(
-            token_ids_logprob=label_token_ids,
-            return_logprob=True,
-            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
-            logprob_start_len=0 if use_multi_item_scoring else -1,
-            stream=False,
-            sampling_params={"max_new_tokens": 0},
-        )
+            if lora_path not in self.lora_ref_cache:
+                raise ValueError(
+                    f"Got LoRA adapter that has never been loaded: {lora_path}\n"
+                    f"All loaded adapters: {self.lora_ref_cache.keys()}."
+                )

-        # Handle string or tokenized query/items
-        if isinstance(query, str) and (
-            isinstance(items, str)
-            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
-        ):
-            # Both query and items are text
-            items_list = [items] if isinstance(items, str) else items
-
-            if use_multi_item_scoring:
-                # Multi-item scoring: create single prompt with delimiter text
-                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
-                # (item_first is ignored for multi-item scoring)
-                delimiter = self.multi_item_delimiter_text
-                combined_items = delimiter.join(items_list)
-                # Add final delimiter after the last item for logprob extraction
-                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
-                batch_request.text = [single_prompt]
-            else:
-                # Single-item scoring: create separate prompts for each item
-                if item_first:
-                    prompts = [f"{item}{query}" for item in items_list]
-                else:
-                    prompts = [f"{query}{item}" for item in items_list]
-                batch_request.text = prompts
-
-        elif (
-            isinstance(query, list)
-            and isinstance(items, list)
-            and items
-            and isinstance(items[0], list)
-        ):
-            # Both query and items are token IDs
-            if use_multi_item_scoring:
-                # Multi-item scoring: concatenate with delimiter token ID
-                # Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
-                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
-                combined_input_ids = self._build_multi_item_token_sequence(
-                    query, items, delimiter_token_id
+            logger.info(f"Reloading evicted adapter: {lora_path}")
+            new_lora_ref = self.lora_ref_cache[lora_path]
+            load_result = await self.load_lora_adapter(
+                LoadLoRAAdapterReqInput(
+                    lora_name=new_lora_ref.lora_name,
+                    lora_path=new_lora_ref.lora_path,
+                    pinned=new_lora_ref.pinned,
                 )
-                batch_request.input_ids = [combined_input_ids]
-            else:
-                # Single-item scoring: process each item separately
-                if item_first:
-                    input_ids_list = [item + query for item in items]
-                else:
-                    input_ids_list = [query + item for item in items]
-                batch_request.input_ids = input_ids_list
-        else:
-            raise ValueError(
-                "Invalid combination of query/items types for score_request."
             )
+            if (
+                not load_result.success
+                and "already loaded" not in load_result.error_message
+            ):
+                raise ValueError(
+                    f"Failed to implicitly load LoRA adapter {lora_path}: {load_result.error_message}"
+                )

-        results = await self.generate_request(batch_request, request).__anext__()
+        # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+        obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

-        if use_multi_item_scoring:
-            # Multi-item scoring: extract scores from input_token_ids_logprobs
-            return self._process_multi_item_scoring_results(
-                results, items, label_token_ids, apply_softmax, batch_request
-            )
-        else:
-            # Single-item scoring: process each result separately
-            return self._process_single_item_scoring_results(
-                results, label_token_ids, apply_softmax
-            )
+    def _log_received_request(self, obj: Union[GenerateReqInput, EmbeddingReqInput]):
+        max_length, skip_names, _ = self.log_request_metadata
+        logger.info(
+            f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
+        )

-    async def watch_load_thread(self):
-        # Only for dp_controller when dp_size > 1
+        # FIXME: This is a temporary fix to get the text from the input ids.
+        # We should remove this once we have a proper way.
         if (
-            self.server_args.dp_size == 1
-            or self.server_args.load_balance_method == "round_robin"
+            self.log_requests_level >= 2
+            and obj.text is None
+            and obj.input_ids is not None
+            and self.tokenizer is not None
         ):
-            return
-
-        while True:
-            await asyncio.sleep(self.server_args.load_watch_interval)
-            loads = await self.get_load_communicator(GetLoadReqInput())
-            load_udpate_req = WatchLoadUpdateReq(loads=loads)
-            self.send_to_scheduler.send_pyobj(load_udpate_req)
+            decoded = self.tokenizer.decode(obj.input_ids, skip_special_tokens=False)
+            obj.text = decoded

     def _trace_request_start(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         created_time: Optional[float] = None,
-        external_trace_header: Optional[Dict] = None,
+        request: Optional[fastapi.Request] = None,
     ):
+        external_trace_header = None
+        if request:
+            if "trace_context" in request.headers:
+                trace_set_remote_propagate_context(request.headers["trace_context"])
+            else:
+                external_trace_header = extract_trace_headers(request.headers)
+
         if obj.is_single:
             bootstrap_room = (
                 obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
             )
@@ -2481,6 +2160,43 @@ async def print_exception_wrapper(func):
         sys.exit(1)


+def _get_processor_wrapper(server_args):
+    try:
+        processor = get_processor(
+            server_args.tokenizer_path,
+            tokenizer_mode=server_args.tokenizer_mode,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            use_fast=not server_args.disable_fast_image_processor,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+            )
+            processor = get_processor(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
+                use_fast=True,
+            )
+        else:
+            raise e
+    return processor
+
+
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class SignalHandler:
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
diff --git a/python/sglang/srt/managers/tokenizer_manager_multiitem_mixin.py b/python/sglang/srt/managers/tokenizer_manager_multiitem_mixin.py
new file mode 100644
index 000000000000..c6e76b65f6ba
--- /dev/null
+++ b/python/sglang/srt/managers/tokenizer_manager_multiitem_mixin.py
@@ -0,0 +1,311 @@
+import logging
+import math
+from typing import Any, Dict, List, Optional, Union
+
+from sglang.srt.managers.io_struct import GenerateReqInput
+
+logger = logging.getLogger(__name__)
+
+
+class TokenizerManagerMultiItemMixin:
+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self,
+        query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '')}. "
+                "This indicates token_ids_logprobs were not computed properly for multi-item scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '')}"
+            )
+
+        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
+        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    async def score_request(
+        self,
+        query: Optional[Union[str, List[int]]] = None,
+        items: Optional[Union[str, List[str], List[List[int]]]] = None,
+        label_token_ids: Optional[List[int]] = None,
+        apply_softmax: bool = False,
+        item_first: bool = False,
+        request: Optional[Any] = None,
+    ) -> List[List[float]]:
+        """
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-Item scoring (default): Process each query+item pair independently
+        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
+           multiple items into a single sequence using delimiter for efficient processing.
+           Note: item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+        - Tokens: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
+        """
+        if label_token_ids is None:
+            raise ValueError("label_token_ids must be provided")
+
+        if self.tokenizer is not None:
+            vocab_size = self.tokenizer.vocab_size
+            for token_id in label_token_ids:
+                if token_id >= vocab_size:
+                    raise ValueError(
+                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
+                    )
+
+        # Check if multi-item scoring is enabled by presence of delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
+        # Handle string or tokenized query/items
+        if isinstance(query, str) and (
+            isinstance(items, str)
+            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
+        ):
+            # Both query and items are text
+            items_list = [items] if isinstance(items, str) else items
+
+            if use_multi_item_scoring:
+                # Multi-item scoring: create single prompt with delimiter text
+                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts
+
+        elif (
+            isinstance(query, list)
+            and isinstance(items, list)
+            and items
+            and isinstance(items[0], list)
+        ):
+            # Both query and items are token IDs
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with delimiter token ID
+                # Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
+            else:
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
+        else:
+            raise ValueError(
+                "Invalid combination of query/items types for score_request."
+            )
+
+        results = await self.generate_request(batch_request, request).__anext__()
+
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        # Imported locally so the mixin module does not require torch at import time.
+        import torch
+
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
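
Note (illustrative, not part of the diff): a minimal sketch of the token layout that
_build_multi_item_token_sequence produces and that _process_multi_item_scoring_results
consumes. The token IDs below are made up; only the shape of the sequence matters.

    # query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
    query = [1, 2, 3]                    # pre-tokenized query
    items = [[10, 11], [20], [30, 31]]   # one token-ID list per item
    delimiter_token_id = 7               # hypothetical delimiter token

    combined = query[:]
    for item in items:
        combined.append(delimiter_token_id)
        combined.extend(item)
    combined.append(delimiter_token_id)  # final delimiter, used for logprob extraction

    assert combined == [1, 2, 3, 7, 10, 11, 7, 20, 7, 30, 31, 7]
    # len(items) + 1 == 4 delimiter positions, which is exactly the
    # expected_logprobs_count that _process_multi_item_scoring_results checks.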