78 changes: 78 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -472,6 +472,84 @@ def save_remote_model(self, **kwargs):
def save_sharded_model(self, **kwargs):
self.collective_rpc("save_sharded_model", **kwargs)

def score(
self,
text_1: str,
text_2: Union[str, List[str]],
positive_token_id: int,
negative_token_id: int,
prepend: bool = False,
) -> Dict:
"""
Score a list of texts with a decoder-only model using positive/negative token probabilities.

A score is computed for each pair of text_1 + text_2 by comparing the model's probability
of generating the positive token versus the negative token immediately after the prompt.
The score is computed as: prob(positive) / (prob(positive) + prob(negative)).

For example, given:
- text_1 = "My name is "
- text_2 = ["John", "you"]
- positive_token_id = token_id("John")
- negative_token_id = token_id("you")

The method will return higher scores for items where the model assigns higher probability
to the positive token compared to the negative token.

Args:
text_1: The prompt text to score against. Must not be empty.
text_2: The text(s) to score. Can be a single string or list of strings.
positive_token_id: The token ID to use for positive scoring. Must be in model's vocabulary.
negative_token_id: The token ID to use for negative scoring. Must be in model's vocabulary.
prepend: If True, prepends text_2 to text_1 (i.e. text_2 + text_1).
If False (default), appends text_2 to text_1 (i.e. text_1 + text_2).
No additional characters (like newlines) are inserted between the texts.

Returns:
Dict containing:
- scores: List[Optional[float]] - Scores in range [0, 1] where higher values indicate stronger
preference for the positive token. An entry is None if logprobs are not available for the target tokens.
- model_info: Dict with model name and context length
- usage: Dict with token usage statistics (prompt_tokens, completion_tokens, cached_tokens, total_tokens)

Raises:
ValueError: If text_1 is empty or a token ID is out of the model's vocabulary.
"""
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.score_request(
text_1=text_1,
text_2=text_2,
positive_token_id=positive_token_id,
negative_token_id=negative_token_id,
prepend=prepend,
request=None
)
)

async def async_score(
self,
text_1: str,
text_2: Union[str, List[str]],
positive_token_id: int,
negative_token_id: int,
prepend: bool = False,
) -> Dict:
"""
Asynchronous version of score method.

See score() for detailed documentation.
"""
return await self.tokenizer_manager.score_request(
text_1=text_1,
text_2=text_2,
positive_token_id=positive_token_id,
negative_token_id=negative_token_id,
prepend=prepend,
request=None
)


def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
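For reference, a minimal usage sketch of the new Engine.score() API. The model path, the access to the tokenizer via tokenizer_manager, and the assumption that each candidate word maps to a single leading token are illustrative, not part of this change:

import sglang as sgl

# Any decoder-only checkpoint supported by SGLang; this path is illustrative.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

tokenizer = engine.tokenizer_manager.tokenizer
# Assumes each candidate maps to a single leading token; real usage should verify this.
positive_id = tokenizer.encode("John", add_special_tokens=False)[0]
negative_id = tokenizer.encode("you", add_special_tokens=False)[0]

result = engine.score(
    text_1="My name is ",
    text_2=["John", "you"],
    positive_token_id=positive_id,
    negative_token_id=negative_id,
)
print(result["scores"])  # e.g. [0.87, 0.12]; higher means the positive token is preferred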
7 changes: 7 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -82,6 +82,7 @@
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
v1_score,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser
@@ -714,6 +715,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
return ORJSONResponse({"predictions": ret})


@app.post("/v1/score")
async def openai_v1_score(raw_request: Request):
"""OpenAI-compatible endpoint for the scoring API. See Engine.score() for detailed documentation."""
return await v1_score(_global_state.tokenizer_manager, raw_request)


def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
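A request against the new endpoint might look like the following sketch. It assumes a server already listening on localhost:30000; the port, model name, and token IDs are made up, and the field names follow the ScoringRequest schema added in protocol.py below:

import requests

payload = {
    "model": "my-served-model",   # served model name (illustrative)
    "text_1": "My name is ",
    "text_2": ["John", "you"],
    "positive_token_id": 3103,    # illustrative token IDs
    "negative_token_id": 499,
    "prepend": False,
}
resp = requests.post("http://localhost:30000/v1/score", json=payload)
print(resp.json())  # {"scores": [...], "model": "...", "usage": {...}, "object": "scoring"}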
81 changes: 81 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -18,6 +18,7 @@
import dataclasses
import json
import logging
import math
import os
import pickle
import signal
@@ -1416,6 +1417,86 @@ def _handle_update_weights_from_disk_req_output(self, recv_obj):
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)

async def score_request(
self,
text_1: str,
text_2: Union[str, List[str]],
positive_token_id: int,
negative_token_id: int,
prepend: bool = False,
request: Optional[Any] = None,
) -> Dict:
"""
Internal implementation of the scoring API. See Engine.score() for detailed documentation.
"""
if isinstance(text_2, str):
text_2 = [text_2]

prompts = [f"{text}{text_1}" for text in text_2] if prepend else [f"{text_1}{text}" for text in text_2]

if self.tokenizer is not None:
vocab_size = self.tokenizer.vocab_size
if positive_token_id >= vocab_size:
raise ValueError(f"Positive token ID {positive_token_id} is out of vocabulary (vocab size: {vocab_size})")
if negative_token_id >= vocab_size:
raise ValueError(f"Negative token ID {negative_token_id} is out of vocabulary (vocab size: {vocab_size})")

batch_request = GenerateReqInput(
text=prompts,
return_logprob=True,
token_ids_logprob=[positive_token_id, negative_token_id],
stream=False,
sampling_params={"max_new_tokens": 1},
)

scores = []
total_prompt_tokens = 0
total_completion_tokens = 0
total_cached_tokens = 0

# Only forward objects that support disconnect checks (e.g. a FastAPI Request); otherwise pass None.
request_to_pass = request if hasattr(request, "is_disconnected") else None
async for results in self.generate_request(batch_request, request_to_pass):
if not isinstance(results, list):
results = [results]

for result in results:
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
pos_logprob = None
neg_logprob = None

if output_logprobs:
# Look up the logprobs returned for the two target tokens at the first generated position
first_position_logprobs = output_logprobs[0]
logprob_dict = {token_id: logprob for logprob, token_id, _ in first_position_logprobs}
pos_logprob = logprob_dict.get(positive_token_id)
neg_logprob = logprob_dict.get(negative_token_id)

if pos_logprob is None or neg_logprob is None:
scores.append(None)
else:
pos_prob = math.exp(pos_logprob)
neg_prob = math.exp(neg_logprob)
score = pos_prob / (pos_prob + neg_prob)
scores.append(score)

total_prompt_tokens += result["meta_info"]["prompt_tokens"]
total_completion_tokens += result["meta_info"]["completion_tokens"]
total_cached_tokens += result["meta_info"]["cached_tokens"]

return {
"scores": scores,
"model_info": {
"model": self.served_model_name,
"context_length": self.context_len,
},
"usage": {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"cached_tokens": total_cached_tokens,
"total_tokens": total_prompt_tokens + total_completion_tokens,
},
}


async def print_exception_wrapper(func):
"""
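The score computed in score_request() is a two-way softmax over the target tokens' log-probabilities, which is equivalent to a sigmoid of their difference. A small standalone sketch of the arithmetic with made-up logprob values:

import math

pos_logprob, neg_logprob = -0.3, -2.1  # illustrative log-probabilities
pos_prob, neg_prob = math.exp(pos_logprob), math.exp(neg_logprob)
score = pos_prob / (pos_prob + neg_prob)  # ~0.86
# Equivalently: 1 / (1 + math.exp(neg_logprob - pos_logprob)) yields the same value.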
31 changes: 31 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
@@ -72,6 +72,8 @@
ToolCall,
TopLogprob,
UsageInfo,
ScoringRequest,
ScoringResponse,
)
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str, get_exception_traceback
@@ -1203,6 +1205,7 @@ def v1_chat_generate_request(
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
request_ids = [req.rid for req in all_requests]

adapted_request = GenerateReqInput(
**prompt_kwargs,
@@ -1920,3 +1923,31 @@ def append_top_logprobs(top_logprobs):
append_top_logprobs(output_top_logprobs)

return ret_logprobs


async def v1_score(tokenizer_manager, raw_request):
try:
# Parse request
request_data = await raw_request.json()
request = ScoringRequest(**request_data)

# Use tokenizer_manager's score_request method directly
result = await tokenizer_manager.score_request(
text_1=request.text_1,
text_2=request.text_2,
positive_token_id=request.positive_token_id,
negative_token_id=request.negative_token_id,
prepend=request.prepend,
request=request,
)

response = ScoringResponse(
scores=result["scores"],
model=request.model,
usage=result["usage"]
)
return response

except Exception as e:
logger.error(f"Error in v1_score: {str(e)}")
return create_error_response(str(e))
16 changes: 16 additions & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -484,3 +484,19 @@ class EmbeddingResponse(BaseModel):
model: str
object: str = "list"
usage: Optional[UsageInfo] = None


class ScoringRequest(BaseModel):
text_1: str
text_2: List[str]
positive_token_id: int
negative_token_id: int
prepend: bool = False
model: str


class ScoringResponse(BaseModel):
scores: List[Optional[float]]
model: str
usage: UsageInfo
object: str = "scoring"
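A quick construction sketch for the new models (values are illustrative; this assumes UsageInfo exposes the usual prompt_tokens / completion_tokens / total_tokens fields):

req = ScoringRequest(
    model="my-served-model",
    text_1="My name is ",
    text_2=["John", "you"],
    positive_token_id=3103,
    negative_token_id=499,
)
resp = ScoringResponse(
    scores=[0.86, None],  # an entry is None when a target token is absent from the returned logprobs
    model=req.model,
    usage=UsageInfo(prompt_tokens=8, completion_tokens=2, total_tokens=10),
)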