From fbdfa297e4b1c020fd64bff314af2daef1e3535a Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Sat, 14 Mar 2026 02:43:06 +0000 Subject: [PATCH 01/12] upd --- python/sglang/test/kits/logprob_kit.py | 463 +++++++++++++++ test/manual/test_logprobs.py | 4 +- .../logprob/test_original_logprobs.py | 532 ++++++++++++++++++ .../sampling/test_original_logprobs.py | 199 ------- .../spec/eagle/test_eagle_infer_b.py | 6 + .../spec/eagle/test_eagle_infer_beta.py | 8 +- .../eagle/test_eagle_logprob_correctness.py | 101 ++++ 7 files changed, 1110 insertions(+), 203 deletions(-) create mode 100644 python/sglang/test/kits/logprob_kit.py create mode 100644 test/registered/logprob/test_original_logprobs.py delete mode 100644 test/registered/sampling/test_original_logprobs.py create mode 100644 test/registered/spec/eagle/test_eagle_logprob_correctness.py diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py new file mode 100644 index 000000000000..96dcca2b65b7 --- /dev/null +++ b/python/sglang/test/kits/logprob_kit.py @@ -0,0 +1,463 @@ +"""Reusable cross-mode log-probability test kit. + +Generates tokens from a target server, then scores the same token sequence +via prefill on a baseline server (or the same server). Compares all logprob +artifacts between the two to verify correctness. + +Verified artifacts: + - output_token_logprobs: per-token value comparison + - input_token_logprobs: per-token value comparison + - output_top_logprobs: top-k token IDs and values + - input_top_logprobs: top-k token IDs and values + - output_token_ids_logprobs: values for user-specified token IDs + - input_token_ids_logprobs: values for user-specified token IDs + - logprob_start_len: boundary correctness + - return_text_in_logprobs: structural validation + +Usage as a Mixin (recommended for eagle / spec-decoding tests): + + from sglang.test.kits.logprob_kit import LogprobCrossModeMixin + + class TestMySpecDecoding(EagleServerBase, LogprobCrossModeMixin): + logprob_decimal_places = 2 # override tolerance as needed + # test methods are inherited from the mixin + +Usage as standalone functions: + + from sglang.test.kits.logprob_kit import ( + run_logprob_cross_mode_check, + run_logprob_start_len_check, + ) + + class TestMyServer(CustomTestCase): + def test_logprobs(self): + run_logprob_cross_mode_check(self, self.base_url) +""" + +import numpy as np +import requests + +CROSS_MODE_PROMPTS = [ + "The capital of France is", + "Explain quantum computing in simple terms:", + "Today is a sunny day and I like", +] +DEFAULT_MAX_NEW_TOKENS = 32 +DEFAULT_TOP_LOGPROBS_NUM = 5 +DEFAULT_PROBE_TOKEN_IDS = [1, 2, 10, 100, 1000] +DEFAULT_DECIMAL_PLACES = 2 + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _generate_with_logprobs( + url, + prompt_or_ids, + max_new_tokens, + top_logprobs_num, + token_ids_logprob, + logprob_start_len, + return_text_in_logprobs=False, +): + if isinstance(prompt_or_ids, str): + payload = {"text": prompt_or_ids} + else: + payload = {"input_ids": prompt_or_ids} + + payload.update( + { + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + "ignore_eos": True, + }, + "return_logprob": True, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": logprob_start_len, + "return_text_in_logprobs": return_text_in_logprobs, + } + ) + if token_ids_logprob is not None: + payload["token_ids_logprob"] = token_ids_logprob + + response = requests.post(url + "/generate", json=payload) + assert response.status_code == 200, f"Server error: {response.text}" + return response.json() + + +def _compare_token_logprobs(test_case, target_lps, baseline_lps, places, tag): + """Compare per-token logprob values, skipping None entries.""" + test_case.assertEqual( + len(target_lps), + len(baseline_lps), + msg=f"[{tag}] length mismatch: {len(target_lps)} vs {len(baseline_lps)}", + ) + + diffs = [] + for i in range(len(target_lps)): + t_val, t_id = target_lps[i][0], target_lps[i][1] + b_val, b_id = baseline_lps[i][0], baseline_lps[i][1] + + if t_val is None or b_val is None: + continue + + test_case.assertEqual( + t_id, + b_id, + msg=f"[{tag}] token_id mismatch at pos {i}: {t_id} vs {b_id}", + ) + test_case.assertAlmostEqual( + t_val, + b_val, + places=places, + msg=f"[{tag}] pos {i}: target={t_val:.6f} vs baseline={b_val:.6f}", + ) + diffs.append(abs(t_val - b_val)) + + if diffs: + print(f"[{tag}] max|diff|={max(diffs):.6f} mean|diff|={np.mean(diffs):.6f}") + + +def _compare_top_logprobs(test_case, target_tops, baseline_tops, places, tag): + """Compare top-k logprobs: common token IDs and their values.""" + test_case.assertEqual( + len(target_tops), + len(baseline_tops), + msg=f"[{tag}] length mismatch", + ) + + for pos in range(len(target_tops)): + if target_tops[pos] is None or baseline_tops[pos] is None: + continue + dec_top = {t[1]: t[0] for t in target_tops[pos] if t[0] is not None} + scr_top = {t[1]: t[0] for t in baseline_tops[pos] if t[0] is not None} + common_ids = set(dec_top.keys()) & set(scr_top.keys()) + test_case.assertGreater( + len(common_ids), + 0, + msg=f"[{tag}] pos {pos}: no common top-k IDs", + ) + for tid in common_ids: + test_case.assertAlmostEqual( + dec_top[tid], + scr_top[tid], + places=places, + msg=f"[{tag}] pos {pos} tid={tid}: " + f"target={dec_top[tid]:.6f} vs baseline={scr_top[tid]:.6f}", + ) + + +def _compare_ids_logprobs(test_case, target_ids, baseline_ids, places, tag): + """Compare token_ids_logprob values for user-specified token IDs.""" + test_case.assertEqual( + len(target_ids), + len(baseline_ids), + msg=f"[{tag}] length mismatch", + ) + + for pos in range(len(target_ids)): + if target_ids[pos] is None or baseline_ids[pos] is None: + continue + dec_map = {t[1]: t[0] for t in target_ids[pos]} + scr_map = {t[1]: t[0] for t in baseline_ids[pos]} + test_case.assertEqual( + set(dec_map.keys()), + set(scr_map.keys()), + msg=f"[{tag}] pos {pos}: token IDs differ", + ) + for tid in dec_map: + test_case.assertAlmostEqual( + dec_map[tid], + scr_map[tid], + places=places, + msg=f"[{tag}] pos {pos} tid={tid}: " + f"target={dec_map[tid]:.6f} vs baseline={scr_map[tid]:.6f}", + ) + + +# --------------------------------------------------------------------------- +# Public API – standalone functions +# --------------------------------------------------------------------------- + + +def run_logprob_cross_mode_check( + test_case, + target_url, + baseline_url=None, + prompts=None, + max_new_tokens=DEFAULT_MAX_NEW_TOKENS, + top_logprobs_num=DEFAULT_TOP_LOGPROBS_NUM, + token_ids_logprob=None, + return_text_in_logprobs=False, + decimal_places=DEFAULT_DECIMAL_PLACES, +): + """Run cross-mode logprob comparison across all artifacts. + + 1. Generate tokens from the target server (with logprob_start_len=0). + 2. Score the full sequence via prefill on the baseline server. + 3. Compare output_token_logprobs, input_token_logprobs, output_top_logprobs, + input_top_logprobs, output_token_ids_logprobs, input_token_ids_logprobs. + + When *baseline_url* is ``None``, the target server itself is used for the + scoring step. Because scoring uses ``max_new_tokens=0`` (pure prefill), + speculative decoding is not involved, making the target server a valid + non-speculative baseline for its own decode logprobs. + + Args: + test_case: ``unittest.TestCase`` instance for assertions. + target_url: URL of the target server (e.g. speculative decoding). + baseline_url: URL of the baseline server. Defaults to *target_url*. + prompts: List of prompt strings. + max_new_tokens: Tokens to generate per prompt. + top_logprobs_num: Top-k count for top_logprobs. + token_ids_logprob: Token IDs for the token_ids_logprob artifact. + return_text_in_logprobs: Whether to include token text. + decimal_places: ``assertAlmostEqual`` precision (``places``). + """ + if baseline_url is None: + baseline_url = target_url + if prompts is None: + prompts = list(CROSS_MODE_PROMPTS) + if token_ids_logprob is None: + token_ids_logprob = list(DEFAULT_PROBE_TOKEN_IDS) + + for round_idx, prompt in enumerate(prompts): + tag_prefix = f"round {round_idx}" + print(f"\n--- Cross-mode check {tag_prefix}: {prompt!r} ---") + + # Step 1: generate from target with logprob_start_len=0 + gen_res = _generate_with_logprobs( + target_url, + prompt, + max_new_tokens, + top_logprobs_num, + token_ids_logprob, + logprob_start_len=0, + return_text_in_logprobs=return_text_in_logprobs, + ) + meta = gen_res["meta_info"] + P = meta["prompt_tokens"] + + input_token_ids = [t[1] for t in meta["input_token_logprobs"]] + output_token_ids = [t[1] for t in meta["output_token_logprobs"]] + full_sequence = input_token_ids + output_token_ids + + # Step 2: score the full sequence via prefill on baseline + score_res = _generate_with_logprobs( + baseline_url, + full_sequence, + max_new_tokens=0, + top_logprobs_num=top_logprobs_num, + token_ids_logprob=token_ids_logprob, + logprob_start_len=0, + return_text_in_logprobs=return_text_in_logprobs, + ) + score_meta = score_res["meta_info"] + + # ----- output_token_logprobs ----- + _compare_token_logprobs( + test_case, + meta["output_token_logprobs"], + score_meta["input_token_logprobs"][P:], + decimal_places, + f"{tag_prefix} output_token_logprobs", + ) + + # ----- input_token_logprobs ----- + _compare_token_logprobs( + test_case, + meta["input_token_logprobs"], + score_meta["input_token_logprobs"][:P], + decimal_places, + f"{tag_prefix} input_token_logprobs", + ) + + # ----- output_top_logprobs ----- + if top_logprobs_num > 0: + _compare_top_logprobs( + test_case, + meta["output_top_logprobs"], + score_meta["input_top_logprobs"][P:], + decimal_places, + f"{tag_prefix} output_top_logprobs", + ) + + # ----- input_top_logprobs ----- + _compare_top_logprobs( + test_case, + meta["input_top_logprobs"], + score_meta["input_top_logprobs"][:P], + decimal_places, + f"{tag_prefix} input_top_logprobs", + ) + + # ----- output_token_ids_logprobs ----- + if token_ids_logprob: + _compare_ids_logprobs( + test_case, + meta["output_token_ids_logprobs"], + score_meta["input_token_ids_logprobs"][P:], + decimal_places, + f"{tag_prefix} output_token_ids_logprobs", + ) + + # ----- input_token_ids_logprobs ----- + target_in_ids = meta.get("input_token_ids_logprobs") + baseline_in_ids = score_meta.get("input_token_ids_logprobs") + if target_in_ids is not None and baseline_in_ids is not None: + _compare_ids_logprobs( + test_case, + target_in_ids, + baseline_in_ids[:P], + decimal_places, + f"{tag_prefix} input_token_ids_logprobs", + ) + + # ----- logprob_start_len boundary ----- + test_case.assertEqual( + len(meta["input_token_logprobs"]), + P, + msg=f"{tag_prefix}: input_token_logprobs length != prompt_tokens", + ) + + # ----- return_text_in_logprobs structural check ----- + if return_text_in_logprobs: + for lp in meta["output_token_logprobs"]: + test_case.assertIsNotNone( + lp[2], + msg=f"{tag_prefix}: output token text should not be None", + ) + + print(f"--- {tag_prefix} passed ---") + + +def run_logprob_start_len_check( + test_case, + target_url, + prompts=None, + max_new_tokens=8, + start_lens=None, +): + """Verify logprob_start_len boundary correctness on the target server. + + For each prompt and start_len, verifies that: + - ``len(input_token_logprobs) == prompt_tokens - logprob_start_len`` + - ``len(input_top_logprobs)`` matches + - ``len(output_token_logprobs) == max_new_tokens`` + """ + if prompts is None: + prompts = list(CROSS_MODE_PROMPTS) + + for prompt in prompts: + # Probe to get prompt_tokens + probe_res = _generate_with_logprobs( + target_url, + prompt, + max_new_tokens=1, + top_logprobs_num=0, + token_ids_logprob=None, + logprob_start_len=-1, + ) + P = probe_res["meta_info"]["prompt_tokens"] + + test_start_lens = ( + start_lens if start_lens is not None else [0, 1, P // 2, P - 1] + ) + + for sl in test_start_lens: + if sl >= P: + continue + with test_case.subTest(prompt=prompt, logprob_start_len=sl): + res = _generate_with_logprobs( + target_url, + prompt, + max_new_tokens, + top_logprobs_num=5, + token_ids_logprob=None, + logprob_start_len=sl, + ) + meta = res["meta_info"] + + expected_len = P - sl + test_case.assertEqual( + len(meta["input_token_logprobs"]), + expected_len, + msg=f"start_len={sl}: input_token_logprobs len", + ) + test_case.assertEqual( + len(meta["input_top_logprobs"]), + expected_len, + msg=f"start_len={sl}: input_top_logprobs len", + ) + test_case.assertEqual( + len(meta["output_token_logprobs"]), + max_new_tokens, + ) + test_case.assertEqual( + meta["prompt_tokens"], + sl + len(meta["input_token_logprobs"]), + ) + + print( + f"[logprob_start_len={sl} {prompt!r}] " + f"input={len(meta['input_token_logprobs'])}, " + f"output={len(meta['output_token_logprobs'])}" + ) + + +# --------------------------------------------------------------------------- +# Public API – Mixin class +# --------------------------------------------------------------------------- + + +class LogprobCrossModeMixin: + """Mixin providing cross-mode logprob test methods. + + Mix into a test class that has ``self.base_url`` pointing to the target + server. Override class attributes to customise behaviour. + + Example:: + + class TestEagleLogprobs(EagleServerBase, LogprobCrossModeMixin): + logprob_decimal_places = 2 + """ + + logprob_decimal_places = DEFAULT_DECIMAL_PLACES + logprob_max_new_tokens = DEFAULT_MAX_NEW_TOKENS + logprob_top_k = DEFAULT_TOP_LOGPROBS_NUM + logprob_prompts = None + logprob_probe_token_ids = None + + def test_cross_mode_logprobs(self): + """Compare decode logprobs against prefill scoring for all artifacts.""" + run_logprob_cross_mode_check( + self, + self.base_url, + prompts=self.logprob_prompts, + max_new_tokens=self.logprob_max_new_tokens, + top_logprobs_num=self.logprob_top_k, + token_ids_logprob=self.logprob_probe_token_ids, + decimal_places=self.logprob_decimal_places, + ) + + def test_cross_mode_logprob_start_len(self): + """Verify logprob_start_len boundary behaviour.""" + run_logprob_start_len_check( + self, + self.base_url, + prompts=self.logprob_prompts, + ) + + def test_cross_mode_return_text_in_logprobs(self): + """Verify return_text_in_logprobs structural correctness.""" + run_logprob_cross_mode_check( + self, + self.base_url, + prompts=(self.logprob_prompts or CROSS_MODE_PROMPTS)[:1], + max_new_tokens=8, + return_text_in_logprobs=True, + decimal_places=self.logprob_decimal_places, + ) diff --git a/test/manual/test_logprobs.py b/test/manual/test_logprobs.py index 5aa68c5ddf92..28b3e2723689 100644 --- a/test/manual/test_logprobs.py +++ b/test/manual/test_logprobs.py @@ -31,12 +31,12 @@ Step 1: Generate Baseline (Before Code Changes) ```bash -python test/srt/test_logprobs.py gen +python test/manual/test_logprobs.py gen ``` Step 2: Test Against Baseline (After Code Changes) ```bash -python test/srt/test_logprobs.py test +python test/manual/test_logprobs.py test ``` This tests your changes against the locally generated baseline from Step 1. The test passes if the maximum and mean differences are within the tolerance thresholds. diff --git a/test/registered/logprob/test_original_logprobs.py b/test/registered/logprob/test_original_logprobs.py new file mode 100644 index 000000000000..a3b5baaae88a --- /dev/null +++ b/test/registered/logprob/test_original_logprobs.py @@ -0,0 +1,532 @@ +"""Test original log probability alignment between SGLang and Hugging Face. + +This test suite verifies the correctness of the `origin_logprobs` output (temperature=1) +and the `logprobs` output (temperature=0.5) in SGLang by comparing it against +raw logit-based probabilities computed directly from a reference Hugging Face model. + +The test covers the following scenarios: +- Next-token prediction: Verifies that the log probability of the next token from + SGLang matches the Hugging Face model. +- Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are + consistent with Hugging Face outputs. +- Specified token IDs: Confirms that the original logprobs for specific token IDs + match the values computed from Hugging Face logits. +- Multi-token decoding: Validates per-step log-probability accuracy across a + complete generation sequence (max_new_tokens=8). +- Input logprobs: Verifies input_token_logprobs and input_top_logprobs against + the Hugging Face reference for input positions. +- logprob_start_len: Ensures correct boundary behavior for the starting index + of input log-probability computation. +""" + +import os +import random +import unittest + +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer + +import sglang as sgl +from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +register_cuda_ci(est_time=120, suite="stage-b-test-small-1-gpu") +register_amd_ci(est_time=180, suite="stage-b-test-small-1-gpu-amd") + +# ------------------------- Configurable via env ------------------------- # +MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST +PROMPTS = [ + "Hello, my name is", + "The future of AI is", + "The president of the United States is", + "The capital of France is ", +] +TOP_LOGPROBS_NUM = 50 +NUM_RANDOM_TOKEN_IDS = 10 +RTOL = 0.20 +ATOL = 0.00 +MULTI_TOKEN_COUNT = 8 +# ------------------------------------------------ + +torch.manual_seed(1234) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(1234) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + +class TestOriginalLogprob(unittest.TestCase): + @classmethod + def setUpClass(cls): + # ----- HF side (float32 weights, loaded once for all tests) ----- + cls.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right") + cls.hf_model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.float32, device_map="auto" + ) + + def setUp(self): + # Shared sampling parameters + self.sampling_params = { + "temperature": 0.5, # SGLang uses 0.5, but original logprobs are used 1.0 + "top_p": 1.0, + "top_k": 10, + "max_new_tokens": 1, + } + + # --------------------------------------------------------------------- + # Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs) + # against a reference HF log‑prob vector. + # --------------------------------------------------------------------- + def assert_logprobs_block_equal( + self, + hf_log_probs: torch.Tensor, # [V] + token_log_probs: list, + top_log_probs: list, + ids_log_probs: list, + random_token_ids: list, + tag: str = "", + ): + vals, idxs, _ = zip(*token_log_probs) + sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32) + sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long) + hf_vals = hf_log_probs[sgl_idxs] + + self.assertTrue( + torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL), + msg=f"[{tag}] token‑level mismatch at indices {sgl_idxs.tolist()}", + ) + + hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1) + + sgl_topk = torch.tensor( + [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][ + :TOP_LOGPROBS_NUM + ], + dtype=torch.float32, + device=self.hf_model.device, + ) + + k = min(hf_topk.numel(), sgl_topk.numel()) + self.assertTrue( + torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL), + msg=f"[{tag}] top‑k mismatch", + ) + + indices = torch.tensor( + random_token_ids, dtype=torch.long, device=hf_log_probs.device + ) + + hf_token_ids = hf_log_probs[indices] + + sgl_token_ids = torch.tensor( + [v for v, _, _ in ids_log_probs[0]], + device=self.hf_model.device, + dtype=torch.float32, + ) + self.assertTrue( + torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL), + msg=f"[{tag}] token‑IDs mismatch", + ) + + # Optional: print max abs diff for quick diagnostics + max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item() + print(f"[{tag}] max|diff| token‑level = {max_diff:.4f}") + + def test_logprob_match(self): + vocab_size = self.tokenizer.vocab_size + + for env_val in ["True", "False"]: + with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): + os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val + + # ----- SGLang side ----- + sgl_engine = sgl.Engine( + model_path=MODEL_ID, + skip_tokenizer_init=True, + trust_remote_code=True, + mem_fraction_static=0.60, + ) + + for prompt in PROMPTS: + random_token_ids = sorted( + random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) + ) + + enc = self.tokenizer(prompt, return_tensors="pt") + input_ids = enc["input_ids"].to(self.hf_model.device) + attn_mask = enc["attention_mask"].to(self.hf_model.device) + + with torch.inference_mode(): + hf_out = self.hf_model( + input_ids=input_ids, + attention_mask=attn_mask, + return_dict=True, + ) + logits = hf_out.logits[:, -1, :] # [1, V] + hf_log_probs = F.log_softmax( + logits.float() / self.sampling_params["temperature"], dim=-1 + )[0] + hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0] + + outputs = sgl_engine.generate( + input_ids=input_ids[0].tolist(), + sampling_params=self.sampling_params, + return_logprob=True, + top_logprobs_num=TOP_LOGPROBS_NUM, + token_ids_logprob=random_token_ids, + ) + + if isinstance(outputs, list): + outputs = outputs[0] + meta = outputs["meta_info"] + + # Check original logprobs only if enabled + if env_val.lower() == "true": + self.assert_logprobs_block_equal( + hf_log_probs=hf_original_log_probs, + token_log_probs=meta["output_token_logprobs"], + top_log_probs=meta["output_top_logprobs"], + ids_log_probs=meta["output_token_ids_logprobs"], + random_token_ids=random_token_ids, + tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})", + ) + else: + # Always check regular logprobs + self.assert_logprobs_block_equal( + hf_log_probs=hf_log_probs, + token_log_probs=meta["output_token_logprobs"], + top_log_probs=meta["output_top_logprobs"], + ids_log_probs=meta["output_token_ids_logprobs"], + random_token_ids=random_token_ids, + tag=f"logprobs SGLang vs HF: {prompt} ({env_val})", + ) + sgl_engine.shutdown() + + # --------------------------------------------------------------------- + # Helper: compute HF per-position log-prob vectors for a token sequence. + # logprobs[i] = log_softmax(logits[i] / temperature) predicts token[i+1]. + # --------------------------------------------------------------------- + def hf_logprobs_for_sequence(self, token_ids, temperature=1.0): + input_tensor = torch.tensor( + [token_ids], dtype=torch.long, device=self.hf_model.device + ) + with torch.inference_mode(): + logits = self.hf_model(input_ids=input_tensor, return_dict=True).logits[0] + if temperature != 1.0: + logits = logits / temperature + return F.log_softmax(logits.float(), dim=-1) + + # --------------------------------------------------------------------- + # Helper: compare SGLang logprob outputs against HF at a single position. + # --------------------------------------------------------------------- + def assert_position_logprobs_match( + self, + hf_lp_vec, + sgl_token_logprob, + sgl_top_logprobs=None, + sgl_ids_logprobs=None, + random_token_ids=None, + tag="", + ): + sgl_val, sgl_idx = sgl_token_logprob[0], sgl_token_logprob[1] + if sgl_val is None: + return + + hf_val = hf_lp_vec[sgl_idx].item() + self.assertTrue( + torch.allclose( + torch.tensor([hf_val]), + torch.tensor([float(sgl_val)]), + rtol=RTOL, + atol=ATOL, + ), + msg=f"[{tag}] token logprob mismatch: HF={hf_val:.6f} vs SGL={sgl_val:.6f}", + ) + + if sgl_top_logprobs is not None: + hf_topk_vals, _ = torch.topk(hf_lp_vec, k=TOP_LOGPROBS_NUM, dim=-1) + sgl_vals = [float(t[0]) for t in sgl_top_logprobs if t and t[0] is not None] + if sgl_vals: + sgl_topk = torch.tensor( + sgl_vals[:TOP_LOGPROBS_NUM], + dtype=torch.float32, + device=self.hf_model.device, + ) + k = min(hf_topk_vals.numel(), sgl_topk.numel()) + self.assertTrue( + torch.allclose( + hf_topk_vals[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL + ), + msg=f"[{tag}] top-k mismatch", + ) + + if sgl_ids_logprobs is not None and random_token_ids: + ids_tensor = torch.tensor( + random_token_ids, dtype=torch.long, device=hf_lp_vec.device + ) + hf_ids_vals = hf_lp_vec[ids_tensor] + sgl_ids_vals = torch.tensor( + [float(v) for v, _, _ in sgl_ids_logprobs], + dtype=torch.float32, + device=self.hf_model.device, + ) + self.assertTrue( + torch.allclose(hf_ids_vals, sgl_ids_vals, rtol=RTOL, atol=ATOL), + msg=f"[{tag}] token-IDs mismatch", + ) + + def test_output_logprob_multi_token(self): + """Multi-token decode: validate per-step logprobs across an 8-token generation.""" + vocab_size = self.tokenizer.vocab_size + sampling_params = { + "temperature": 0.5, + "top_p": 1.0, + "top_k": 10, + "max_new_tokens": MULTI_TOKEN_COUNT, + } + + for env_val in ["True", "False"]: + with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): + os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val + + sgl_engine = sgl.Engine( + model_path=MODEL_ID, + skip_tokenizer_init=True, + trust_remote_code=True, + mem_fraction_static=0.60, + ) + + for prompt in PROMPTS: + random_token_ids = sorted( + random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) + ) + + enc = self.tokenizer(prompt, return_tensors="pt") + input_ids = enc["input_ids"][0].tolist() + prompt_len = len(input_ids) + + outputs = sgl_engine.generate( + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=True, + top_logprobs_num=TOP_LOGPROBS_NUM, + token_ids_logprob=random_token_ids, + ) + + if isinstance(outputs, list): + outputs = outputs[0] + meta = outputs["meta_info"] + + output_token_ids = [t[1] for t in meta["output_token_logprobs"]] + self.assertEqual(len(output_token_ids), MULTI_TOKEN_COUNT) + + full_sequence = input_ids + output_token_ids + hf_temp = ( + 1.0 if env_val == "True" else sampling_params["temperature"] + ) + hf_all_lp = self.hf_logprobs_for_sequence( + full_sequence, temperature=hf_temp + ) + + max_diff = 0.0 + for step in range(MULTI_TOKEN_COUNT): + hf_pos = prompt_len - 1 + step + self.assert_position_logprobs_match( + hf_lp_vec=hf_all_lp[hf_pos], + sgl_token_logprob=meta["output_token_logprobs"][step], + sgl_top_logprobs=meta["output_top_logprobs"][step], + sgl_ids_logprobs=meta["output_token_ids_logprobs"][step], + random_token_ids=random_token_ids, + tag=f"Multi-token step {step}: {prompt} ({env_val})", + ) + sgl_val = meta["output_token_logprobs"][step][0] + hf_val = hf_all_lp[hf_pos][ + meta["output_token_logprobs"][step][1] + ].item() + max_diff = max(max_diff, abs(hf_val - sgl_val)) + + print( + f"[Multi-token {prompt} ({env_val})] max|diff| = {max_diff:.4f}" + ) + + sgl_engine.shutdown() + + def test_input_logprobs(self): + """Input logprobs: verify input_token_logprobs and input_top_logprobs against HF.""" + vocab_size = self.tokenizer.vocab_size + sampling_params = { + "temperature": 0.5, + "top_p": 1.0, + "top_k": 10, + "max_new_tokens": 1, + } + + for env_val in ["True", "False"]: + with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): + os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val + + sgl_engine = sgl.Engine( + model_path=MODEL_ID, + skip_tokenizer_init=True, + trust_remote_code=True, + mem_fraction_static=0.60, + ) + + for prompt in PROMPTS: + random_token_ids = sorted( + random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) + ) + + enc = self.tokenizer(prompt, return_tensors="pt") + input_ids = enc["input_ids"][0].tolist() + prompt_len = len(input_ids) + + outputs = sgl_engine.generate( + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=True, + logprob_start_len=0, + top_logprobs_num=TOP_LOGPROBS_NUM, + token_ids_logprob=random_token_ids, + ) + + if isinstance(outputs, list): + outputs = outputs[0] + meta = outputs["meta_info"] + + input_token_logprobs = meta["input_token_logprobs"] + input_top_logprobs = meta["input_top_logprobs"] + input_ids_logprobs = meta.get("input_token_ids_logprobs") + + self.assertEqual(len(input_token_logprobs), prompt_len) + + hf_temp = ( + 1.0 if env_val == "True" else sampling_params["temperature"] + ) + hf_all_lp = self.hf_logprobs_for_sequence( + input_ids, temperature=hf_temp + ) + + # Skip position 0 (no preceding logits for the first token) + max_diff = 0.0 + for pos in range(1, prompt_len): + hf_lp_vec = hf_all_lp[pos - 1] + + sgl_top = ( + input_top_logprobs[pos] if input_top_logprobs else None + ) + sgl_ids = ( + input_ids_logprobs[pos] if input_ids_logprobs else None + ) + + self.assert_position_logprobs_match( + hf_lp_vec=hf_lp_vec, + sgl_token_logprob=input_token_logprobs[pos], + sgl_top_logprobs=sgl_top, + sgl_ids_logprobs=sgl_ids, + random_token_ids=random_token_ids, + tag=f"Input pos {pos}: {prompt} ({env_val})", + ) + + sgl_val = input_token_logprobs[pos][0] + if sgl_val is not None: + hf_val = hf_lp_vec[input_token_logprobs[pos][1]].item() + max_diff = max(max_diff, abs(hf_val - sgl_val)) + + print( + f"[Input logprobs {prompt} ({env_val})] " + f"max|diff| = {max_diff:.4f}" + ) + + sgl_engine.shutdown() + + def test_logprob_start_len(self): + """Verify logprob_start_len correctly controls the starting index.""" + sampling_params = { + "temperature": 0.5, + "max_new_tokens": 4, + } + + os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = "True" + sgl_engine = sgl.Engine( + model_path=MODEL_ID, + skip_tokenizer_init=True, + trust_remote_code=True, + mem_fraction_static=0.60, + ) + + for prompt in PROMPTS: + enc = self.tokenizer(prompt, return_tensors="pt") + input_ids = enc["input_ids"][0].tolist() + prompt_len = len(input_ids) + + hf_all_lp = self.hf_logprobs_for_sequence(input_ids, temperature=1.0) + + for start_len in [0, 1, prompt_len // 2, prompt_len - 1]: + with self.subTest(prompt=prompt, logprob_start_len=start_len): + outputs = sgl_engine.generate( + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=True, + logprob_start_len=start_len, + top_logprobs_num=5, + ) + + if isinstance(outputs, list): + outputs = outputs[0] + meta = outputs["meta_info"] + + expected_input_len = prompt_len - start_len + self.assertEqual( + len(meta["input_token_logprobs"]), + expected_input_len, + msg=f"input_token_logprobs length mismatch " + f"for start_len={start_len}", + ) + self.assertEqual( + len(meta["input_top_logprobs"]), + expected_input_len, + msg=f"input_top_logprobs length mismatch " + f"for start_len={start_len}", + ) + self.assertEqual( + len(meta["output_token_logprobs"]), + sampling_params["max_new_tokens"], + ) + self.assertEqual( + meta["prompt_tokens"], + start_len + len(meta["input_token_logprobs"]), + ) + + # Spot-check: verify the first returned position against HF + if start_len > 0 and expected_input_len > 0: + first_lp = meta["input_token_logprobs"][0] + hf_lp = hf_all_lp[start_len - 1] + sgl_val = first_lp[0] + if sgl_val is not None: + hf_val = hf_lp[first_lp[1]].item() + self.assertTrue( + torch.allclose( + torch.tensor([hf_val]), + torch.tensor([float(sgl_val)]), + rtol=RTOL, + atol=ATOL, + ), + msg=f"First position (start_len={start_len}) " + f"HF={hf_val:.6f} vs SGL={sgl_val:.6f}", + ) + + print( + f"[logprob_start_len={start_len} {prompt}] " + f"input_logprobs_len=" + f"{len(meta['input_token_logprobs'])}, " + f"output_logprobs_len=" + f"{len(meta['output_token_logprobs'])}" + ) + + sgl_engine.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/sampling/test_original_logprobs.py b/test/registered/sampling/test_original_logprobs.py deleted file mode 100644 index ed78daa60b94..000000000000 --- a/test/registered/sampling/test_original_logprobs.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Test original log probability alignment between SGLang and Hugging Face. - -This test suite verifies the correctness of the `origin_logprobs` output (temperature=1) -and the `logprobs` output (temperature=0.5) in SGLang by comparing it against -raw logit-based probabilities computed directly from a reference Hugging Face model. - -The test covers the following scenarios: -- Next-token prediction: Verifies that the log probability of the next token from - SGLang matches the Hugging Face model. -- Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are - consistent with Hugging Face outputs. -- Specified token IDs: Confirms that the original logprobs for specific token IDs - match the values computed from Hugging Face logits. -""" - -import os -import random -import unittest - -import torch -import torch.nn.functional as F -from transformers import AutoModelForCausalLM, AutoTokenizer - -import sglang as sgl -from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST - -register_cuda_ci(est_time=41, suite="stage-b-test-small-1-gpu") -register_amd_ci(est_time=60, suite="stage-b-test-small-1-gpu-amd") - -# ------------------------- Configurable via env ------------------------- # -MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST -PROMPTS = [ - "Hello, my name is", - "The future of AI is", - "The president of the United States is", - "The capital of France is ", -] -TOP_LOGPROBS_NUM = 50 -NUM_RANDOM_TOKEN_IDS = 10 -RTOL = 0.20 -ATOL = 0.00 -# ------------------------------------------------ - -torch.manual_seed(1234) -if torch.cuda.is_available(): - torch.cuda.manual_seed_all(1234) - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False - - -class TestOriginalLogprob(unittest.TestCase): - def setUp(self): - # ----- HF side (float32 weights) ----- - self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right") - self.hf_model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype=torch.float32, device_map="auto" - ) - - # Shared sampling parameters - self.sampling_params = { - "temperature": 0.5, # SGLang uses 0.5, but original logprobs are used 1.0 - "top_p": 1.0, - "top_k": 10, - "max_new_tokens": 1, - } - - # --------------------------------------------------------------------- - # Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs) - # against a reference HF log‑prob vector. - # --------------------------------------------------------------------- - def assert_logprobs_block_equal( - self, - hf_log_probs: torch.Tensor, # [V] - token_log_probs: list, - top_log_probs: list, - ids_log_probs: list, - random_token_ids: list, - tag: str = "", - ): - vals, idxs, _ = zip(*token_log_probs) - sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32) - sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long) - hf_vals = hf_log_probs[sgl_idxs] - - self.assertTrue( - torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL), - msg=f"[{tag}] token‑level mismatch at indices {sgl_idxs.tolist()}", - ) - - hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1) - - sgl_topk = torch.tensor( - [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][ - :TOP_LOGPROBS_NUM - ], - dtype=torch.float32, - device=self.hf_model.device, - ) - - k = min(hf_topk.numel(), sgl_topk.numel()) - self.assertTrue( - torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL), - msg=f"[{tag}] top‑k mismatch", - ) - - indices = torch.tensor( - random_token_ids, dtype=torch.long, device=hf_log_probs.device - ) - - hf_token_ids = hf_log_probs[indices] - - sgl_token_ids = torch.tensor( - [v for v, _, _ in ids_log_probs[0]], - device=self.hf_model.device, - dtype=torch.float32, - ) - self.assertTrue( - torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL), - msg=f"[{tag}] token‑IDs mismatch", - ) - - # Optional: print max abs diff for quick diagnostics - max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item() - print(f"[{tag}] max|diff| token‑level = {max_diff:.4f}") - - def test_logprob_match(self): - vocab_size = self.tokenizer.vocab_size - - for env_val in ["True", "False"]: - with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): - os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val - - # ----- SGLang side ----- - sgl_engine = sgl.Engine( - model_path=MODEL_ID, - skip_tokenizer_init=True, - trust_remote_code=True, - mem_fraction_static=0.60, - ) - - for prompt in PROMPTS: - random_token_ids = sorted( - random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) - ) - - enc = self.tokenizer(prompt, return_tensors="pt") - input_ids = enc["input_ids"].to(self.hf_model.device) - attn_mask = enc["attention_mask"].to(self.hf_model.device) - - with torch.inference_mode(): - hf_out = self.hf_model( - input_ids=input_ids, - attention_mask=attn_mask, - return_dict=True, - ) - logits = hf_out.logits[:, -1, :] # [1, V] - hf_log_probs = F.log_softmax( - logits.float() / self.sampling_params["temperature"], dim=-1 - )[0] - hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0] - - outputs = sgl_engine.generate( - input_ids=input_ids[0].tolist(), - sampling_params=self.sampling_params, - return_logprob=True, - top_logprobs_num=TOP_LOGPROBS_NUM, - token_ids_logprob=random_token_ids, - ) - - if isinstance(outputs, list): - outputs = outputs[0] - meta = outputs["meta_info"] - - # Check original logprobs only if enabled - if env_val.lower() == "true": - self.assert_logprobs_block_equal( - hf_log_probs=hf_original_log_probs, - token_log_probs=meta["output_token_logprobs"], - top_log_probs=meta["output_top_logprobs"], - ids_log_probs=meta["output_token_ids_logprobs"], - random_token_ids=random_token_ids, - tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})", - ) - else: - # Always check regular logprobs - self.assert_logprobs_block_equal( - hf_log_probs=hf_log_probs, - token_log_probs=meta["output_token_logprobs"], - top_log_probs=meta["output_top_logprobs"], - ids_log_probs=meta["output_token_ids_logprobs"], - random_token_ids=random_token_ids, - tag=f"logprobs SGLang vs HF: {prompt} ({env_val})", - ) - sgl_engine.shutdown() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/spec/eagle/test_eagle_infer_b.py b/test/registered/spec/eagle/test_eagle_infer_b.py index 011aee68ed77..e38a95c66f0f 100644 --- a/test/registered/spec/eagle/test_eagle_infer_b.py +++ b/test/registered/spec/eagle/test_eagle_infer_b.py @@ -18,6 +18,7 @@ RunningTimeoutTwoWaveMixin, WaitingTimeoutMixin, ) +from sglang.test.kits.logprob_kit import LogprobCrossModeMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test from sglang.test.server_fixtures.eagle_fixture import EagleServerBase from sglang.test.test_utils import DEFAULT_TARGET_MODEL_EAGLE, run_logprob_check @@ -376,5 +377,10 @@ def setUpClass(cls): super().setUpClass() +class TestEAGLELogprobCrossMode(EagleServerBase, LogprobCrossModeMixin): + logprob_decimal_places = 2 + extra_args = ["--chunked-prefill-size", 128, "--max-running-requests", 8] + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/spec/eagle/test_eagle_infer_beta.py b/test/registered/spec/eagle/test_eagle_infer_beta.py index 3124877a9a73..8fe571bca807 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta.py @@ -8,6 +8,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.kits.logprob_kit import LogprobCrossModeMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test from sglang.test.test_utils import ( @@ -19,10 +20,13 @@ popen_launch_server, ) -register_cuda_ci(est_time=283, suite="stage-b-test-small-1-gpu") +register_cuda_ci(est_time=400, suite="stage-b-test-small-1-gpu") -class TestEagleServerBase(CustomTestCase, MatchedStopMixin): +class TestEagleServerBase(CustomTestCase, MatchedStopMixin, LogprobCrossModeMixin): + logprob_decimal_places = 2 + logprob_decimal_places = 2 + logprob_decimal_places = 2 max_running_requests = 64 attention_backend = "triton" spec_steps = 5 diff --git a/test/registered/spec/eagle/test_eagle_logprob_correctness.py b/test/registered/spec/eagle/test_eagle_logprob_correctness.py new file mode 100644 index 000000000000..f968e620c10c --- /dev/null +++ b/test/registered/spec/eagle/test_eagle_logprob_correctness.py @@ -0,0 +1,101 @@ +"""Strict log-probability correctness tests for EAGLE speculative decoding. + +This test file uses the cross-mode logprob kit to verify that speculative +decoding (both v1 and v2) produces logprob values within tight tolerance +(decimal places >= 2) of non-speculative prefill scoring. + +Replaces the earlier loose-tolerance checks (max_diff < 0.255 ≈ places=0) +with comprehensive artifact-level comparison using the logprob kit. + +Tested artifacts: + - output_token_logprobs / input_token_logprobs + - output_top_logprobs / input_top_logprobs + - output_token_ids_logprobs / input_token_ids_logprobs + - logprob_start_len boundary correctness + - return_text_in_logprobs structural validation +""" + +import unittest + +from sglang.srt.environ import envs +from sglang.srt.utils.common import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.logprob_kit import LogprobCrossModeMixin +from sglang.test.server_fixtures.eagle_fixture import EagleServerBase +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_EAGLE, + DEFAULT_TARGET_MODEL_EAGLE, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=300, suite="stage-b-test-large-1-gpu") + + +# --------------------------------------------------------------------------- +# Eagle v1 (non-overlap) with tight logprob tolerance +# --------------------------------------------------------------------------- +class TestEagleV1LogprobCorrectness(EagleServerBase, LogprobCrossModeMixin): + """EAGLE v1: cross-mode logprob correctness at places=2.""" + + logprob_decimal_places = 2 + extra_args = ["--chunked-prefill-size", 128, "--max-running-requests", 8] + + +# --------------------------------------------------------------------------- +# Eagle v2 (overlap / spec-v2) with tight logprob tolerance +# --------------------------------------------------------------------------- +class TestEagleV2LogprobCorrectness(CustomTestCase, LogprobCrossModeMixin): + """EAGLE v2 (spec-v2): cross-mode logprob correctness at places=2.""" + + logprob_decimal_places = 2 + model = DEFAULT_TARGET_MODEL_EAGLE + draft_model = DEFAULT_DRAFT_MODEL_EAGLE + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + launch_args = [ + "--trust-remote-code", + "--attention-backend", + "triton", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model", + cls.draft_model, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "6", + "--page-size", + "1", + "--mem-fraction-static", + "0.75", + "--max-running-requests", + "8", + ] + with envs.SGLANG_ENABLE_SPEC_V2.override( + True + ), envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override( + True + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=launch_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + +if __name__ == "__main__": + unittest.main() From 60c6bd203a3399d43dbfe582a6c4a87d7034013e Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Sat, 14 Mar 2026 02:47:33 +0000 Subject: [PATCH 02/12] revert back --- test/registered/{logprob => sampling}/test_original_logprobs.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/registered/{logprob => sampling}/test_original_logprobs.py (100%) diff --git a/test/registered/logprob/test_original_logprobs.py b/test/registered/sampling/test_original_logprobs.py similarity index 100% rename from test/registered/logprob/test_original_logprobs.py rename to test/registered/sampling/test_original_logprobs.py From f2a580747a872319780da10d3b2cf9416d29fe51 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Sat, 14 Mar 2026 04:07:39 +0000 Subject: [PATCH 03/12] upd --- python/sglang/test/kits/logprob_kit.py | 122 +++++++++++++++- .../sampling/test_original_logprobs.py | 132 +++++------------- 2 files changed, 154 insertions(+), 100 deletions(-) diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py index 96dcca2b65b7..9c5032f9ac57 100644 --- a/python/sglang/test/kits/logprob_kit.py +++ b/python/sglang/test/kits/logprob_kit.py @@ -32,6 +32,19 @@ class TestMySpecDecoding(EagleServerBase, LogprobCrossModeMixin): class TestMyServer(CustomTestCase): def test_logprobs(self): run_logprob_cross_mode_check(self, self.base_url) + +HF ground-truth comparison utilities: + + from sglang.test.kits.logprob_kit import ( + hf_logprobs_for_sequence, + assert_position_logprobs_match, + ) + + hf_lp = hf_logprobs_for_sequence(hf_model, token_ids, temperature=1.0) + assert_position_logprobs_match( + test_case, ref_lp_vec=hf_lp[pos], sgl_token_logprob=..., + top_k=50, rtol=0.20, atol=0.0, tag="...", + ) """ import numpy as np @@ -179,7 +192,114 @@ def _compare_ids_logprobs(test_case, target_ids, baseline_ids, places, tag): # --------------------------------------------------------------------------- -# Public API – standalone functions +# Public API – HF ground-truth comparison utilities +# --------------------------------------------------------------------------- + + +def hf_logprobs_for_sequence(hf_model, token_ids, temperature=1.0): + """Run a HuggingFace forward pass and return per-position log-prob vectors. + + Returns a ``[T, V]`` tensor where ``logprobs[i]`` is the log-softmax of + ``logits[i] / temperature``. ``logits[i]`` predicts ``token[i+1]``. + + Args: + hf_model: A HuggingFace ``AutoModelForCausalLM`` instance. + token_ids: List of token IDs forming the sequence. + temperature: Temperature divisor applied before log-softmax. + """ + import torch + import torch.nn.functional as F + + input_tensor = torch.tensor([token_ids], dtype=torch.long, device=hf_model.device) + with torch.inference_mode(): + logits = hf_model(input_ids=input_tensor, return_dict=True).logits[0] + if temperature != 1.0: + logits = logits / temperature + return F.log_softmax(logits.float(), dim=-1) + + +def assert_position_logprobs_match( + test_case, + ref_lp_vec, + sgl_token_logprob, + sgl_top_logprobs=None, + sgl_ids_logprobs=None, + token_ids=None, + top_k=5, + rtol=0.20, + atol=0.0, + tag="", +): + """Compare SGLang logprob output against a reference vector at one position. + + Checks token-level logprob, top-k logprobs, and specified-token-ID + logprobs against the reference ``[V]`` log-prob vector (e.g. from HF). + + Positions whose SGLang logprob value is ``None`` (e.g. position 0 of + input logprobs) are silently skipped. + + Args: + test_case: ``unittest.TestCase`` instance for assertions. + ref_lp_vec: ``[V]`` reference log-prob vector (torch.Tensor). + sgl_token_logprob: ``(logprob, token_id, text)`` tuple from SGLang. + sgl_top_logprobs: List of ``(logprob, token_id, text)`` tuples, or + ``None`` to skip top-k comparison. + sgl_ids_logprobs: List of ``(logprob, token_id, text)`` tuples for + user-specified IDs, or ``None`` to skip. + token_ids: Token IDs corresponding to *sgl_ids_logprobs*. + top_k: Number of top logprobs to compare. + rtol: Relative tolerance for ``torch.allclose``. + atol: Absolute tolerance for ``torch.allclose``. + tag: Descriptive tag for error messages. + """ + import torch + + sgl_val, sgl_idx = sgl_token_logprob[0], sgl_token_logprob[1] + if sgl_val is None: + return + + ref_val = ref_lp_vec[sgl_idx].item() + test_case.assertTrue( + torch.allclose( + torch.tensor([ref_val]), + torch.tensor([float(sgl_val)]), + rtol=rtol, + atol=atol, + ), + msg=f"[{tag}] token logprob mismatch: ref={ref_val:.6f} vs SGL={sgl_val:.6f}", + ) + + if sgl_top_logprobs is not None: + ref_topk_vals, _ = torch.topk(ref_lp_vec, k=top_k, dim=-1) + sgl_vals = [float(t[0]) for t in sgl_top_logprobs if t and t[0] is not None] + if sgl_vals: + sgl_topk = torch.tensor( + sgl_vals[:top_k], + dtype=torch.float32, + device=ref_lp_vec.device, + ) + k = min(ref_topk_vals.numel(), sgl_topk.numel()) + test_case.assertTrue( + torch.allclose(ref_topk_vals[:k], sgl_topk[:k], rtol=rtol, atol=atol), + msg=f"[{tag}] top-k mismatch", + ) + + if sgl_ids_logprobs is not None and token_ids: + ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=ref_lp_vec.device) + ref_ids_vals = ref_lp_vec[ids_tensor] + sgl_ids_vals = torch.tensor( + [float(v) for v, _, _ in sgl_ids_logprobs], + dtype=torch.float32, + device=ref_lp_vec.device, + ) + test_case.assertTrue( + torch.allclose(ref_ids_vals, sgl_ids_vals, rtol=rtol, atol=atol), + msg=f"[{tag}] token-IDs mismatch", + ) + + +# --------------------------------------------------------------------------- +# Public API – cross-mode comparison functions # --------------------------------------------------------------------------- diff --git a/test/registered/sampling/test_original_logprobs.py b/test/registered/sampling/test_original_logprobs.py index a3b5baaae88a..aef3e69f366a 100644 --- a/test/registered/sampling/test_original_logprobs.py +++ b/test/registered/sampling/test_original_logprobs.py @@ -29,6 +29,10 @@ import sglang as sgl from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.kits.logprob_kit import ( + assert_position_logprobs_match, + hf_logprobs_for_sequence, +) from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST register_cuda_ci(est_time=120, suite="stage-b-test-small-1-gpu") @@ -203,79 +207,6 @@ def test_logprob_match(self): ) sgl_engine.shutdown() - # --------------------------------------------------------------------- - # Helper: compute HF per-position log-prob vectors for a token sequence. - # logprobs[i] = log_softmax(logits[i] / temperature) predicts token[i+1]. - # --------------------------------------------------------------------- - def hf_logprobs_for_sequence(self, token_ids, temperature=1.0): - input_tensor = torch.tensor( - [token_ids], dtype=torch.long, device=self.hf_model.device - ) - with torch.inference_mode(): - logits = self.hf_model(input_ids=input_tensor, return_dict=True).logits[0] - if temperature != 1.0: - logits = logits / temperature - return F.log_softmax(logits.float(), dim=-1) - - # --------------------------------------------------------------------- - # Helper: compare SGLang logprob outputs against HF at a single position. - # --------------------------------------------------------------------- - def assert_position_logprobs_match( - self, - hf_lp_vec, - sgl_token_logprob, - sgl_top_logprobs=None, - sgl_ids_logprobs=None, - random_token_ids=None, - tag="", - ): - sgl_val, sgl_idx = sgl_token_logprob[0], sgl_token_logprob[1] - if sgl_val is None: - return - - hf_val = hf_lp_vec[sgl_idx].item() - self.assertTrue( - torch.allclose( - torch.tensor([hf_val]), - torch.tensor([float(sgl_val)]), - rtol=RTOL, - atol=ATOL, - ), - msg=f"[{tag}] token logprob mismatch: HF={hf_val:.6f} vs SGL={sgl_val:.6f}", - ) - - if sgl_top_logprobs is not None: - hf_topk_vals, _ = torch.topk(hf_lp_vec, k=TOP_LOGPROBS_NUM, dim=-1) - sgl_vals = [float(t[0]) for t in sgl_top_logprobs if t and t[0] is not None] - if sgl_vals: - sgl_topk = torch.tensor( - sgl_vals[:TOP_LOGPROBS_NUM], - dtype=torch.float32, - device=self.hf_model.device, - ) - k = min(hf_topk_vals.numel(), sgl_topk.numel()) - self.assertTrue( - torch.allclose( - hf_topk_vals[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL - ), - msg=f"[{tag}] top-k mismatch", - ) - - if sgl_ids_logprobs is not None and random_token_ids: - ids_tensor = torch.tensor( - random_token_ids, dtype=torch.long, device=hf_lp_vec.device - ) - hf_ids_vals = hf_lp_vec[ids_tensor] - sgl_ids_vals = torch.tensor( - [float(v) for v, _, _ in sgl_ids_logprobs], - dtype=torch.float32, - device=self.hf_model.device, - ) - self.assertTrue( - torch.allclose(hf_ids_vals, sgl_ids_vals, rtol=RTOL, atol=ATOL), - msg=f"[{tag}] token-IDs mismatch", - ) - def test_output_logprob_multi_token(self): """Multi-token decode: validate per-step logprobs across an 8-token generation.""" vocab_size = self.tokenizer.vocab_size @@ -325,19 +256,23 @@ def test_output_logprob_multi_token(self): hf_temp = ( 1.0 if env_val == "True" else sampling_params["temperature"] ) - hf_all_lp = self.hf_logprobs_for_sequence( - full_sequence, temperature=hf_temp + hf_all_lp = hf_logprobs_for_sequence( + self.hf_model, full_sequence, temperature=hf_temp ) max_diff = 0.0 for step in range(MULTI_TOKEN_COUNT): hf_pos = prompt_len - 1 + step - self.assert_position_logprobs_match( - hf_lp_vec=hf_all_lp[hf_pos], + assert_position_logprobs_match( + self, + ref_lp_vec=hf_all_lp[hf_pos], sgl_token_logprob=meta["output_token_logprobs"][step], sgl_top_logprobs=meta["output_top_logprobs"][step], sgl_ids_logprobs=meta["output_token_ids_logprobs"][step], - random_token_ids=random_token_ids, + token_ids=random_token_ids, + top_k=TOP_LOGPROBS_NUM, + rtol=RTOL, + atol=ATOL, tag=f"Multi-token step {step}: {prompt} ({env_val})", ) sgl_val = meta["output_token_logprobs"][step][0] @@ -404,8 +339,8 @@ def test_input_logprobs(self): hf_temp = ( 1.0 if env_val == "True" else sampling_params["temperature"] ) - hf_all_lp = self.hf_logprobs_for_sequence( - input_ids, temperature=hf_temp + hf_all_lp = hf_logprobs_for_sequence( + self.hf_model, input_ids, temperature=hf_temp ) # Skip position 0 (no preceding logits for the first token) @@ -420,12 +355,16 @@ def test_input_logprobs(self): input_ids_logprobs[pos] if input_ids_logprobs else None ) - self.assert_position_logprobs_match( - hf_lp_vec=hf_lp_vec, + assert_position_logprobs_match( + self, + ref_lp_vec=hf_lp_vec, sgl_token_logprob=input_token_logprobs[pos], sgl_top_logprobs=sgl_top, sgl_ids_logprobs=sgl_ids, - random_token_ids=random_token_ids, + token_ids=random_token_ids, + top_k=TOP_LOGPROBS_NUM, + rtol=RTOL, + atol=ATOL, tag=f"Input pos {pos}: {prompt} ({env_val})", ) @@ -461,7 +400,9 @@ def test_logprob_start_len(self): input_ids = enc["input_ids"][0].tolist() prompt_len = len(input_ids) - hf_all_lp = self.hf_logprobs_for_sequence(input_ids, temperature=1.0) + hf_all_lp = hf_logprobs_for_sequence( + self.hf_model, input_ids, temperature=1.0 + ) for start_len in [0, 1, prompt_len // 2, prompt_len - 1]: with self.subTest(prompt=prompt, logprob_start_len=start_len): @@ -501,21 +442,14 @@ def test_logprob_start_len(self): # Spot-check: verify the first returned position against HF if start_len > 0 and expected_input_len > 0: - first_lp = meta["input_token_logprobs"][0] - hf_lp = hf_all_lp[start_len - 1] - sgl_val = first_lp[0] - if sgl_val is not None: - hf_val = hf_lp[first_lp[1]].item() - self.assertTrue( - torch.allclose( - torch.tensor([hf_val]), - torch.tensor([float(sgl_val)]), - rtol=RTOL, - atol=ATOL, - ), - msg=f"First position (start_len={start_len}) " - f"HF={hf_val:.6f} vs SGL={sgl_val:.6f}", - ) + assert_position_logprobs_match( + self, + ref_lp_vec=hf_all_lp[start_len - 1], + sgl_token_logprob=meta["input_token_logprobs"][0], + rtol=RTOL, + atol=ATOL, + tag=f"First position (start_len={start_len})", + ) print( f"[logprob_start_len={start_len} {prompt}] " From 7e7b6e18469f05b375f580dc1ad08b0ac29c652b Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Mon, 16 Mar 2026 20:27:27 +0000 Subject: [PATCH 04/12] upd --- test/registered/spec/eagle/test_eagle_infer_beta.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/registered/spec/eagle/test_eagle_infer_beta.py b/test/registered/spec/eagle/test_eagle_infer_beta.py index 8fe571bca807..5fac3006060c 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta.py @@ -24,8 +24,6 @@ class TestEagleServerBase(CustomTestCase, MatchedStopMixin, LogprobCrossModeMixin): - logprob_decimal_places = 2 - logprob_decimal_places = 2 logprob_decimal_places = 2 max_running_requests = 64 attention_backend = "triton" From 3b614c49824b2b213f13b084817df67c7cdeac4a Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Mon, 16 Mar 2026 20:46:34 +0000 Subject: [PATCH 05/12] upd --- .../eagle/test_eagle_logprob_correctness.py | 101 ------------------ 1 file changed, 101 deletions(-) delete mode 100644 test/registered/spec/eagle/test_eagle_logprob_correctness.py diff --git a/test/registered/spec/eagle/test_eagle_logprob_correctness.py b/test/registered/spec/eagle/test_eagle_logprob_correctness.py deleted file mode 100644 index f968e620c10c..000000000000 --- a/test/registered/spec/eagle/test_eagle_logprob_correctness.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Strict log-probability correctness tests for EAGLE speculative decoding. - -This test file uses the cross-mode logprob kit to verify that speculative -decoding (both v1 and v2) produces logprob values within tight tolerance -(decimal places >= 2) of non-speculative prefill scoring. - -Replaces the earlier loose-tolerance checks (max_diff < 0.255 ≈ places=0) -with comprehensive artifact-level comparison using the logprob kit. - -Tested artifacts: - - output_token_logprobs / input_token_logprobs - - output_top_logprobs / input_top_logprobs - - output_token_ids_logprobs / input_token_ids_logprobs - - logprob_start_len boundary correctness - - return_text_in_logprobs structural validation -""" - -import unittest - -from sglang.srt.environ import envs -from sglang.srt.utils.common import kill_process_tree -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.kits.logprob_kit import LogprobCrossModeMixin -from sglang.test.server_fixtures.eagle_fixture import EagleServerBase -from sglang.test.test_utils import ( - DEFAULT_DRAFT_MODEL_EAGLE, - DEFAULT_TARGET_MODEL_EAGLE, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - -register_cuda_ci(est_time=300, suite="stage-b-test-large-1-gpu") - - -# --------------------------------------------------------------------------- -# Eagle v1 (non-overlap) with tight logprob tolerance -# --------------------------------------------------------------------------- -class TestEagleV1LogprobCorrectness(EagleServerBase, LogprobCrossModeMixin): - """EAGLE v1: cross-mode logprob correctness at places=2.""" - - logprob_decimal_places = 2 - extra_args = ["--chunked-prefill-size", 128, "--max-running-requests", 8] - - -# --------------------------------------------------------------------------- -# Eagle v2 (overlap / spec-v2) with tight logprob tolerance -# --------------------------------------------------------------------------- -class TestEagleV2LogprobCorrectness(CustomTestCase, LogprobCrossModeMixin): - """EAGLE v2 (spec-v2): cross-mode logprob correctness at places=2.""" - - logprob_decimal_places = 2 - model = DEFAULT_TARGET_MODEL_EAGLE - draft_model = DEFAULT_DRAFT_MODEL_EAGLE - - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - launch_args = [ - "--trust-remote-code", - "--attention-backend", - "triton", - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model", - cls.draft_model, - "--speculative-num-steps", - "5", - "--speculative-eagle-topk", - "1", - "--speculative-num-draft-tokens", - "6", - "--page-size", - "1", - "--mem-fraction-static", - "0.75", - "--max-running-requests", - "8", - ] - with envs.SGLANG_ENABLE_SPEC_V2.override( - True - ), envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override( - True - ): - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=launch_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - -if __name__ == "__main__": - unittest.main() From acfe2be1ba7f42c7c61b20ebc7d09ce9a6784bec Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Mon, 16 Mar 2026 22:04:53 +0000 Subject: [PATCH 06/12] upd --- python/sglang/test/kits/logprob_kit.py | 71 ++++++++- .../spec/eagle/test_eagle_infer_b.py | 144 +----------------- .../spec/eagle/test_eagle_infer_beta.py | 2 +- 3 files changed, 73 insertions(+), 144 deletions(-) diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py index 9c5032f9ac57..f477a53758d0 100644 --- a/python/sglang/test/kits/logprob_kit.py +++ b/python/sglang/test/kits/logprob_kit.py @@ -47,9 +47,15 @@ def test_logprobs(self): ) """ +import random +from concurrent.futures import ThreadPoolExecutor +from functools import partial + import numpy as np import requests +from sglang.test.test_utils import run_logprob_check + CROSS_MODE_PROMPTS = [ "The capital of France is", "Explain quantum computing in simple terms:", @@ -58,7 +64,7 @@ def test_logprobs(self): DEFAULT_MAX_NEW_TOKENS = 32 DEFAULT_TOP_LOGPROBS_NUM = 5 DEFAULT_PROBE_TOKEN_IDS = [1, 2, 10, 100, 1000] -DEFAULT_DECIMAL_PLACES = 2 +DEFAULT_DECIMAL_PLACES = 1 # --------------------------------------------------------------------------- @@ -528,6 +534,64 @@ def run_logprob_start_len_check( ) +def run_logprob_mixed_check( + test_case, + base_url, + input_lens=None, + output_lens=None, + logprob_start_lens=None, + max_workers=8, +): + """Stress-test logprob shape correctness with many parameter combinations. + + Sends concurrent requests with various (input_len, output_len, + logprob_start_len, return_logprob, top_logprobs_num) combos and verifies + that the returned array lengths are all correct. + + Args: + test_case: ``unittest.TestCase`` instance. + base_url: Server URL. + input_lens: List of input lengths to test. + output_lens: List of output lengths to test. + logprob_start_lens: List of logprob_start_len values to test. + max_workers: Concurrency for the thread pool. + """ + if input_lens is None: + input_lens = [200, 500, 1000, 2000] + if output_lens is None: + output_lens = [4, 8] + if logprob_start_lens is None: + logprob_start_lens = [0, 100, 300, 800, 1998] + + args = [] + temperature = 0 + for input_len in input_lens: + for output_len in output_lens: + for start_len in logprob_start_lens: + for return_logprob in [True, False]: + for top_logprobs_num in [0, 5]: + if start_len >= input_len: + continue + args.append( + ( + input_len, + output_len, + temperature, + start_len, + return_logprob, + top_logprobs_num, + ) + ) + + random.shuffle(args) + print(f"[logprob_mixed] running {len(args)} parameter combinations") + func = partial(run_logprob_check, test_case) + with ThreadPoolExecutor(max_workers) as executor: + list(executor.map(func, args)) + + print(f"[logprob_mixed] all {len(args)} combinations passed") + + # --------------------------------------------------------------------------- # Public API – Mixin class # --------------------------------------------------------------------------- @@ -553,6 +617,7 @@ class TestEagleLogprobs(EagleServerBase, LogprobCrossModeMixin): def test_cross_mode_logprobs(self): """Compare decode logprobs against prefill scoring for all artifacts.""" + print(f"Testing cross-mode logprobs for {self.base_url}") run_logprob_cross_mode_check( self, self.base_url, @@ -581,3 +646,7 @@ def test_cross_mode_return_text_in_logprobs(self): return_text_in_logprobs=True, decimal_places=self.logprob_decimal_places, ) + + def test_cross_mode_logprob_mixed(self): + """Stress-test logprob shape correctness with many parameter combos.""" + run_logprob_mixed_check(self, self.base_url) diff --git a/test/registered/spec/eagle/test_eagle_infer_b.py b/test/registered/spec/eagle/test_eagle_infer_b.py index e38a95c66f0f..e4fdfe184f4f 100644 --- a/test/registered/spec/eagle/test_eagle_infer_b.py +++ b/test/registered/spec/eagle/test_eagle_infer_b.py @@ -4,10 +4,8 @@ import time import unittest from concurrent.futures import ThreadPoolExecutor -from functools import partial from types import SimpleNamespace -import numpy as np import requests from sglang.srt.environ import envs @@ -21,7 +19,7 @@ from sglang.test.kits.logprob_kit import LogprobCrossModeMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test from sglang.test.server_fixtures.eagle_fixture import EagleServerBase -from sglang.test.test_utils import DEFAULT_TARGET_MODEL_EAGLE, run_logprob_check +from sglang.test.test_utils import DEFAULT_TARGET_MODEL_EAGLE register_cuda_ci(est_time=1100, suite="stage-b-test-large-1-gpu") @@ -97,144 +95,6 @@ def test_gsm8k(self): # Wait a little bit so that the memory check happens. time.sleep(4) - def test_logprob_start_len(self): - logprob_start_len = 4 - new_tokens = 4 - prompts = [ - "I have a very good idea on", - "Today is a sunndy day and", - ] - - response = requests.post( - self.base_url + "/generate", - json={ - "text": prompts, - "sampling_params": { - "temperature": 0, - "max_new_tokens": new_tokens, - }, - "return_logprob": True, - "top_logprobs_num": 5, - "logprob_start_len": logprob_start_len, - }, - ) - response_json = response.json() - print(json.dumps(response_json, indent=2)) - - for res in response_json: - self.assertEqual( - res["meta_info"]["prompt_tokens"], - logprob_start_len + len(res["meta_info"]["input_token_logprobs"]), - ) - - self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens) - self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens) - - def test_logprob_match(self): - """Test the output logprobs are close to the input logprobs if we run a prefill again.""" - - def run_generate( - prompt, - return_logprob=False, - max_new_tokens=512, - logprob_start_len=-1, - temperature=1.0, - ): - - if isinstance(prompt, str): - prompt_kwargs = {"text": prompt} - else: - prompt_kwargs = {"input_ids": prompt} - - response = requests.post( - self.base_url + "/generate", - json={ - **prompt_kwargs, - "sampling_params": { - "temperature": temperature, - "max_new_tokens": max_new_tokens, - "ignore_eos": True, - }, - "return_logprob": return_logprob, - "return_text_in_logprobs": True, - "logprob_start_len": logprob_start_len, - "temp_scaled_logprobs": True, - }, - ) - return response.json() - - prompt = "I have a very good idea on how to" - - for temperature in [1.0]: - gen = run_generate( - prompt, - return_logprob=True, - logprob_start_len=0, - temperature=temperature, - ) - output_logprobs = np.array( - [x[0] for x in gen["meta_info"]["output_token_logprobs"]] - ) - num_prompts_tokens = gen["meta_info"]["prompt_tokens"] - - input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]] - output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]] - - new_prompt = input_tokens + output_tokens - score = run_generate( - new_prompt, - return_logprob=True, - logprob_start_len=0, - max_new_tokens=0, - temperature=temperature, - ) - output_logprobs_score = np.array( - [ - x[0] - for x in score["meta_info"]["input_token_logprobs"][ - num_prompts_tokens: - ] - ] - ) - - print(f"{output_logprobs[-10:]=}") - print(f"{output_logprobs_score[-10:]=}") - - diff = np.abs(output_logprobs - output_logprobs_score) - max_diff = np.max(diff) - self.assertLess(max_diff, 0.255) - - def test_logprob_mixed(self): - args = [] - temperature = 0 - # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num - # Llama 2 context length seems to be only 2k, so we can only test small length. - for input_len in [200, 500, 1000, 2000]: - for output_len in [4, 8]: - for logprob_start_len in [0, 100, 300, 800, 1998]: - for return_logprob in [True, False]: - for top_logprobs_num in [0, 5]: - - if logprob_start_len >= input_len: - continue - - args.append( - ( - input_len, - output_len, - temperature, - logprob_start_len, - return_logprob, - top_logprobs_num, - ) - ) - - random.shuffle(args) - - func = partial(run_logprob_check, self) - with ThreadPoolExecutor(8) as executor: - list(executor.map(func, args)) - def test_penalty_mixed(self): args = [ {}, @@ -378,7 +238,7 @@ def setUpClass(cls): class TestEAGLELogprobCrossMode(EagleServerBase, LogprobCrossModeMixin): - logprob_decimal_places = 2 + logprob_decimal_places = 1 extra_args = ["--chunked-prefill-size", 128, "--max-running-requests", 8] diff --git a/test/registered/spec/eagle/test_eagle_infer_beta.py b/test/registered/spec/eagle/test_eagle_infer_beta.py index 5fac3006060c..8eb666122beb 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta.py @@ -24,7 +24,7 @@ class TestEagleServerBase(CustomTestCase, MatchedStopMixin, LogprobCrossModeMixin): - logprob_decimal_places = 2 + logprob_decimal_places = 1 max_running_requests = 64 attention_backend = "triton" spec_steps = 5 From 3f46954e1fe11304eec2d5f137a314eda945f108 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Mon, 16 Mar 2026 23:45:05 +0000 Subject: [PATCH 07/12] upd --- .../spec/eagle/test_eagle_infer_beta.py | 101 ------------------ 1 file changed, 101 deletions(-) diff --git a/test/registered/spec/eagle/test_eagle_infer_beta.py b/test/registered/spec/eagle/test_eagle_infer_beta.py index 8eb666122beb..d346bffb27b9 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta.py @@ -1,7 +1,6 @@ import unittest from types import SimpleNamespace -import numpy as np import requests from sglang.srt.environ import envs @@ -103,106 +102,6 @@ def test_gsm8k(self): ) # 0.3333 for 60 questions; 0.234 for 1319 questions assert self.process.poll() is None - def test_logprob_spec_v2_match(self): - """Verify spec v2 decode logprobs match prefill scoring logprobs. - - Generate tokens with spec v2, then score the same sequence via - prefill-only (no speculation). The two sets of logprobs should be - close, validating that spec v2 computes logprobs correctly. - - Runs two rounds with different prompts to catch state-dependent bugs. - """ - top_k = 5 - probe_token_ids = [1, 2, 10, 100, 1000] - prompts = [ - "The capital of France is", - "Explain quantum computing in simple terms:", - ] - - for round_idx, prompt in enumerate(prompts): - with self.subTest(round=round_idx, prompt=prompt): - gen_res = requests.post( - self.base_url + "/generate", - json={ - "text": prompt, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - "ignore_eos": True, - }, - "return_logprob": True, - "top_logprobs_num": top_k, - "token_ids_logprob": probe_token_ids, - "logprob_start_len": 0, - }, - ).json() - - decode_logprobs = gen_res["meta_info"]["output_token_logprobs"] - decode_top_logprobs = gen_res["meta_info"]["output_top_logprobs"] - decode_tid_logprobs = gen_res["meta_info"]["output_token_ids_logprobs"] - input_token_ids = [ - t[1] for t in gen_res["meta_info"]["input_token_logprobs"] - ] - output_token_ids = [t[1] for t in decode_logprobs] - num_prompt_tokens = gen_res["meta_info"]["prompt_tokens"] - - score_res = requests.post( - self.base_url + "/generate", - json={ - "input_ids": input_token_ids + output_token_ids, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 0, - }, - "return_logprob": True, - "top_logprobs_num": top_k, - "token_ids_logprob": probe_token_ids, - "logprob_start_len": 0, - }, - ).json() - - score_logprobs = score_res["meta_info"]["input_token_logprobs"][ - num_prompt_tokens: - ] - score_top_logprobs = score_res["meta_info"]["input_top_logprobs"][ - num_prompt_tokens: - ] - score_tid_logprobs = score_res["meta_info"]["input_token_ids_logprobs"][ - num_prompt_tokens: - ] - - self.assertEqual(len(decode_logprobs), len(score_logprobs)) - - # Check per-token logprobs - decode_vals = np.array([t[0] for t in decode_logprobs]) - score_vals = np.array([t[0] for t in score_logprobs]) - max_diff = np.max(np.abs(decode_vals - score_vals)) - print( - f"[round {round_idx}] prompt={prompt!r} " - f"logprob max_diff={max_diff:.6f}" - ) - print(f"[round {round_idx}] decode_vals[-5:]={decode_vals[-5:]}") - print(f"[round {round_idx}] score_vals[-5:]={score_vals[-5:]}") - self.assertLess(max_diff, 0.255) - - # Check top-k logprobs - for pos in range(len(decode_logprobs)): - dec_top = {t[1]: t[0] for t in decode_top_logprobs[pos]} - scr_top = {t[1]: t[0] for t in score_top_logprobs[pos]} - common_ids = set(dec_top.keys()) & set(scr_top.keys()) - self.assertGreater(len(common_ids), 0) - for tid in common_ids: - self.assertAlmostEqual(dec_top[tid], scr_top[tid], delta=0.255) - - # Check token_ids_logprob - self.assertEqual(len(decode_tid_logprobs), len(score_tid_logprobs)) - for pos in range(len(decode_tid_logprobs)): - dec_tid = {t[1]: t[0] for t in decode_tid_logprobs[pos]} - scr_tid = {t[1]: t[0] for t in score_tid_logprobs[pos]} - self.assertEqual(set(dec_tid.keys()), set(scr_tid.keys())) - for tid in dec_tid: - self.assertAlmostEqual(dec_tid[tid], scr_tid[tid], delta=0.255) - def test_token_ids_logprob_ragged(self): """Regression: get_token_ids_logprobs_raw crashes on ragged token_ids_logprob lists. From ff5027f35da605610a63a2adee07ddc750ebbf9c Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Wed, 18 Mar 2026 22:11:41 +0000 Subject: [PATCH 08/12] upd --- python/sglang/test/kits/logprob_kit.py | 58 +- .../sampling/test_original_logprobs.py | 521 +++++------------- 2 files changed, 169 insertions(+), 410 deletions(-) diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py index f477a53758d0..018f6a7b9df4 100644 --- a/python/sglang/test/kits/logprob_kit.py +++ b/python/sglang/test/kits/logprob_kit.py @@ -14,25 +14,6 @@ - logprob_start_len: boundary correctness - return_text_in_logprobs: structural validation -Usage as a Mixin (recommended for eagle / spec-decoding tests): - - from sglang.test.kits.logprob_kit import LogprobCrossModeMixin - - class TestMySpecDecoding(EagleServerBase, LogprobCrossModeMixin): - logprob_decimal_places = 2 # override tolerance as needed - # test methods are inherited from the mixin - -Usage as standalone functions: - - from sglang.test.kits.logprob_kit import ( - run_logprob_cross_mode_check, - run_logprob_start_len_check, - ) - - class TestMyServer(CustomTestCase): - def test_logprobs(self): - run_logprob_cross_mode_check(self, self.base_url) - HF ground-truth comparison utilities: from sglang.test.kits.logprob_kit import ( @@ -72,27 +53,42 @@ def test_logprobs(self): # --------------------------------------------------------------------------- -def _generate_with_logprobs( +def generate_with_logprobs( url, prompt_or_ids, max_new_tokens, top_logprobs_num, - token_ids_logprob, - logprob_start_len, + token_ids_logprob=None, + logprob_start_len=-1, return_text_in_logprobs=False, + sampling_params=None, ): + """Send a generate request with logprob options to an SGLang server. + + Args: + url: Server base URL. + prompt_or_ids: Text string or list of token IDs. + max_new_tokens: Number of tokens to generate. + top_logprobs_num: Number of top logprobs to return. + token_ids_logprob: Optional token IDs for the token_ids_logprob artifact. + logprob_start_len: Starting position for input logprobs (-1 to skip). + return_text_in_logprobs: Include token text in logprob tuples. + sampling_params: Optional dict of sampling parameters. Defaults to + ``{"temperature": 0, "ignore_eos": True}``. ``max_new_tokens`` + is always set from the explicit argument. + """ if isinstance(prompt_or_ids, str): payload = {"text": prompt_or_ids} else: payload = {"input_ids": prompt_or_ids} + if sampling_params is None: + sampling_params = {"temperature": 0, "ignore_eos": True} + sampling_params = {**sampling_params, "max_new_tokens": max_new_tokens} + payload.update( { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - "ignore_eos": True, - }, + "sampling_params": sampling_params, "return_logprob": True, "top_logprobs_num": top_logprobs_num, "logprob_start_len": logprob_start_len, @@ -355,7 +351,7 @@ def run_logprob_cross_mode_check( print(f"\n--- Cross-mode check {tag_prefix}: {prompt!r} ---") # Step 1: generate from target with logprob_start_len=0 - gen_res = _generate_with_logprobs( + gen_res = generate_with_logprobs( target_url, prompt, max_new_tokens, @@ -372,7 +368,7 @@ def run_logprob_cross_mode_check( full_sequence = input_token_ids + output_token_ids # Step 2: score the full sequence via prefill on baseline - score_res = _generate_with_logprobs( + score_res = generate_with_logprobs( baseline_url, full_sequence, max_new_tokens=0, @@ -479,7 +475,7 @@ def run_logprob_start_len_check( for prompt in prompts: # Probe to get prompt_tokens - probe_res = _generate_with_logprobs( + probe_res = generate_with_logprobs( target_url, prompt, max_new_tokens=1, @@ -497,7 +493,7 @@ def run_logprob_start_len_check( if sl >= P: continue with test_case.subTest(prompt=prompt, logprob_start_len=sl): - res = _generate_with_logprobs( + res = generate_with_logprobs( target_url, prompt, max_new_tokens, diff --git a/test/registered/sampling/test_original_logprobs.py b/test/registered/sampling/test_original_logprobs.py index aef3e69f366a..819ed2e69dcf 100644 --- a/test/registered/sampling/test_original_logprobs.py +++ b/test/registered/sampling/test_original_logprobs.py @@ -13,10 +13,9 @@ match the values computed from Hugging Face logits. - Multi-token decoding: Validates per-step log-probability accuracy across a complete generation sequence (max_new_tokens=8). -- Input logprobs: Verifies input_token_logprobs and input_top_logprobs against - the Hugging Face reference for input positions. -- logprob_start_len: Ensures correct boundary behavior for the starting index - of input log-probability computation. + +Two test classes run the same suite with SGLANG_RETURN_ORIGINAL_LOGPROB=True +and False respectively, each launching its own server instance. """ import os @@ -24,16 +23,21 @@ import unittest import torch -import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer -import sglang as sgl +from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.kits.logprob_kit import ( assert_position_logprobs_match, + generate_with_logprobs, hf_logprobs_for_sequence, ) -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) register_cuda_ci(est_time=120, suite="stage-b-test-small-1-gpu") register_amd_ci(est_time=180, suite="stage-b-test-small-1-gpu-amd") @@ -46,12 +50,17 @@ "The president of the United States is", "The capital of France is ", ] +SAMPLING_PARAMS = { + "temperature": 0.5, + "top_p": 1.0, + "top_k": 10, +} TOP_LOGPROBS_NUM = 50 NUM_RANDOM_TOKEN_IDS = 10 RTOL = 0.20 -ATOL = 0.00 +ATOL = 0.04 MULTI_TOKEN_COUNT = 8 -# ------------------------------------------------ +# ----------------------------------------------------------------------- # torch.manual_seed(1234) if torch.cuda.is_available(): @@ -60,406 +69,160 @@ torch.backends.cudnn.allow_tf32 = False -class TestOriginalLogprob(unittest.TestCase): +class _OriginalLogprobBase(unittest.TestCase): + """Base class — subclasses set ``return_original_logprob`` to "True" or "False".""" + + return_original_logprob = None + @classmethod def setUpClass(cls): - # ----- HF side (float32 weights, loaded once for all tests) ----- + if cls is _OriginalLogprobBase: + raise unittest.SkipTest("Base class") + cls.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right") cls.hf_model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float32, device_map="auto" ) - def setUp(self): - # Shared sampling parameters - self.sampling_params = { - "temperature": 0.5, # SGLang uses 0.5, but original logprobs are used 1.0 - "top_p": 1.0, - "top_k": 10, - "max_new_tokens": 1, - } - - # --------------------------------------------------------------------- - # Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs) - # against a reference HF log‑prob vector. - # --------------------------------------------------------------------- - def assert_logprobs_block_equal( - self, - hf_log_probs: torch.Tensor, # [V] - token_log_probs: list, - top_log_probs: list, - ids_log_probs: list, - random_token_ids: list, - tag: str = "", - ): - vals, idxs, _ = zip(*token_log_probs) - sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32) - sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long) - hf_vals = hf_log_probs[sgl_idxs] - - self.assertTrue( - torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL), - msg=f"[{tag}] token‑level mismatch at indices {sgl_idxs.tolist()}", - ) - - hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1) - - sgl_topk = torch.tensor( - [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][ - :TOP_LOGPROBS_NUM - ], - dtype=torch.float32, - device=self.hf_model.device, + env = os.environ.copy() + env["SGLANG_RETURN_ORIGINAL_LOGPROB"] = cls.return_original_logprob + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + MODEL_ID, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--trust-remote-code", "--mem-fraction-static", "0.60"], + env=env, ) - k = min(hf_topk.numel(), sgl_topk.numel()) - self.assertTrue( - torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL), - msg=f"[{tag}] top‑k mismatch", - ) - - indices = torch.tensor( - random_token_ids, dtype=torch.long, device=hf_log_probs.device - ) - - hf_token_ids = hf_log_probs[indices] - - sgl_token_ids = torch.tensor( - [v for v, _, _ in ids_log_probs[0]], - device=self.hf_model.device, - dtype=torch.float32, - ) - self.assertTrue( - torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL), - msg=f"[{tag}] token‑IDs mismatch", + @classmethod + def tearDownClass(cls): + if cls is _OriginalLogprobBase: + return + kill_process_tree(cls.process.pid) + + @property + def hf_temperature(self): + """HF reference temperature: 1.0 for original logprobs, else request temp.""" + return ( + 1.0 + if self.return_original_logprob == "True" + else SAMPLING_PARAMS["temperature"] ) - # Optional: print max abs diff for quick diagnostics - max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item() - print(f"[{tag}] max|diff| token‑level = {max_diff:.4f}") + # --------------------------------------------------------------------- # + # Tests + # --------------------------------------------------------------------- # def test_logprob_match(self): - vocab_size = self.tokenizer.vocab_size - - for env_val in ["True", "False"]: - with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): - os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val - - # ----- SGLang side ----- - sgl_engine = sgl.Engine( - model_path=MODEL_ID, - skip_tokenizer_init=True, - trust_remote_code=True, - mem_fraction_static=0.60, - ) - - for prompt in PROMPTS: - random_token_ids = sorted( - random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) + """Single-token: verify next-token logprobs against HF.""" + for prompt in PROMPTS: + with self.subTest(prompt=prompt): + random_token_ids = sorted( + random.sample( + range(self.tokenizer.vocab_size), NUM_RANDOM_TOKEN_IDS ) + ) + input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][ + 0 + ].tolist() + + res = generate_with_logprobs( + self.base_url, + input_ids, + max_new_tokens=1, + top_logprobs_num=TOP_LOGPROBS_NUM, + token_ids_logprob=random_token_ids, + sampling_params=SAMPLING_PARAMS, + ) + meta = res["meta_info"] - enc = self.tokenizer(prompt, return_tensors="pt") - input_ids = enc["input_ids"].to(self.hf_model.device) - attn_mask = enc["attention_mask"].to(self.hf_model.device) - - with torch.inference_mode(): - hf_out = self.hf_model( - input_ids=input_ids, - attention_mask=attn_mask, - return_dict=True, - ) - logits = hf_out.logits[:, -1, :] # [1, V] - hf_log_probs = F.log_softmax( - logits.float() / self.sampling_params["temperature"], dim=-1 - )[0] - hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0] - - outputs = sgl_engine.generate( - input_ids=input_ids[0].tolist(), - sampling_params=self.sampling_params, - return_logprob=True, - top_logprobs_num=TOP_LOGPROBS_NUM, - token_ids_logprob=random_token_ids, - ) + hf_lp = hf_logprobs_for_sequence( + self.hf_model, input_ids, temperature=self.hf_temperature + ) - if isinstance(outputs, list): - outputs = outputs[0] - meta = outputs["meta_info"] - - # Check original logprobs only if enabled - if env_val.lower() == "true": - self.assert_logprobs_block_equal( - hf_log_probs=hf_original_log_probs, - token_log_probs=meta["output_token_logprobs"], - top_log_probs=meta["output_top_logprobs"], - ids_log_probs=meta["output_token_ids_logprobs"], - random_token_ids=random_token_ids, - tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})", - ) - else: - # Always check regular logprobs - self.assert_logprobs_block_equal( - hf_log_probs=hf_log_probs, - token_log_probs=meta["output_token_logprobs"], - top_log_probs=meta["output_top_logprobs"], - ids_log_probs=meta["output_token_ids_logprobs"], - random_token_ids=random_token_ids, - tag=f"logprobs SGLang vs HF: {prompt} ({env_val})", - ) - sgl_engine.shutdown() + assert_position_logprobs_match( + self, + ref_lp_vec=hf_lp[-1], + sgl_token_logprob=meta["output_token_logprobs"][0], + sgl_top_logprobs=meta["output_top_logprobs"][0], + sgl_ids_logprobs=meta["output_token_ids_logprobs"][0], + token_ids=random_token_ids, + top_k=TOP_LOGPROBS_NUM, + rtol=RTOL, + atol=ATOL, + tag=f"SGLang vs HF: {prompt} ({self.return_original_logprob})", + ) def test_output_logprob_multi_token(self): """Multi-token decode: validate per-step logprobs across an 8-token generation.""" - vocab_size = self.tokenizer.vocab_size - sampling_params = { - "temperature": 0.5, - "top_p": 1.0, - "top_k": 10, - "max_new_tokens": MULTI_TOKEN_COUNT, - } - - for env_val in ["True", "False"]: - with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): - os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val - - sgl_engine = sgl.Engine( - model_path=MODEL_ID, - skip_tokenizer_init=True, - trust_remote_code=True, - mem_fraction_static=0.60, - ) - - for prompt in PROMPTS: - random_token_ids = sorted( - random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) - ) - - enc = self.tokenizer(prompt, return_tensors="pt") - input_ids = enc["input_ids"][0].tolist() - prompt_len = len(input_ids) - - outputs = sgl_engine.generate( - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=True, - top_logprobs_num=TOP_LOGPROBS_NUM, - token_ids_logprob=random_token_ids, - ) - - if isinstance(outputs, list): - outputs = outputs[0] - meta = outputs["meta_info"] - - output_token_ids = [t[1] for t in meta["output_token_logprobs"]] - self.assertEqual(len(output_token_ids), MULTI_TOKEN_COUNT) - - full_sequence = input_ids + output_token_ids - hf_temp = ( - 1.0 if env_val == "True" else sampling_params["temperature"] - ) - hf_all_lp = hf_logprobs_for_sequence( - self.hf_model, full_sequence, temperature=hf_temp - ) - - max_diff = 0.0 - for step in range(MULTI_TOKEN_COUNT): - hf_pos = prompt_len - 1 + step - assert_position_logprobs_match( - self, - ref_lp_vec=hf_all_lp[hf_pos], - sgl_token_logprob=meta["output_token_logprobs"][step], - sgl_top_logprobs=meta["output_top_logprobs"][step], - sgl_ids_logprobs=meta["output_token_ids_logprobs"][step], - token_ids=random_token_ids, - top_k=TOP_LOGPROBS_NUM, - rtol=RTOL, - atol=ATOL, - tag=f"Multi-token step {step}: {prompt} ({env_val})", - ) - sgl_val = meta["output_token_logprobs"][step][0] - hf_val = hf_all_lp[hf_pos][ - meta["output_token_logprobs"][step][1] - ].item() - max_diff = max(max_diff, abs(hf_val - sgl_val)) - - print( - f"[Multi-token {prompt} ({env_val})] max|diff| = {max_diff:.4f}" + for prompt in PROMPTS: + with self.subTest(prompt=prompt): + random_token_ids = sorted( + random.sample( + range(self.tokenizer.vocab_size), NUM_RANDOM_TOKEN_IDS ) - - sgl_engine.shutdown() - - def test_input_logprobs(self): - """Input logprobs: verify input_token_logprobs and input_top_logprobs against HF.""" - vocab_size = self.tokenizer.vocab_size - sampling_params = { - "temperature": 0.5, - "top_p": 1.0, - "top_k": 10, - "max_new_tokens": 1, - } - - for env_val in ["True", "False"]: - with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): - os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val - - sgl_engine = sgl.Engine( - model_path=MODEL_ID, - skip_tokenizer_init=True, - trust_remote_code=True, - mem_fraction_static=0.60, ) + input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][ + 0 + ].tolist() + prompt_len = len(input_ids) + + res = generate_with_logprobs( + self.base_url, + input_ids, + max_new_tokens=MULTI_TOKEN_COUNT, + top_logprobs_num=TOP_LOGPROBS_NUM, + token_ids_logprob=random_token_ids, + sampling_params=SAMPLING_PARAMS, + ) + meta = res["meta_info"] - for prompt in PROMPTS: - random_token_ids = sorted( - random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) - ) - - enc = self.tokenizer(prompt, return_tensors="pt") - input_ids = enc["input_ids"][0].tolist() - prompt_len = len(input_ids) - - outputs = sgl_engine.generate( - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=True, - logprob_start_len=0, - top_logprobs_num=TOP_LOGPROBS_NUM, - token_ids_logprob=random_token_ids, - ) - - if isinstance(outputs, list): - outputs = outputs[0] - meta = outputs["meta_info"] - - input_token_logprobs = meta["input_token_logprobs"] - input_top_logprobs = meta["input_top_logprobs"] - input_ids_logprobs = meta.get("input_token_ids_logprobs") - - self.assertEqual(len(input_token_logprobs), prompt_len) + output_token_ids = [t[1] for t in meta["output_token_logprobs"]] + self.assertEqual(len(output_token_ids), MULTI_TOKEN_COUNT) - hf_temp = ( - 1.0 if env_val == "True" else sampling_params["temperature"] - ) - hf_all_lp = hf_logprobs_for_sequence( - self.hf_model, input_ids, temperature=hf_temp - ) + full_sequence = input_ids + output_token_ids + hf_all_lp = hf_logprobs_for_sequence( + self.hf_model, full_sequence, temperature=self.hf_temperature + ) - # Skip position 0 (no preceding logits for the first token) - max_diff = 0.0 - for pos in range(1, prompt_len): - hf_lp_vec = hf_all_lp[pos - 1] - - sgl_top = ( - input_top_logprobs[pos] if input_top_logprobs else None - ) - sgl_ids = ( - input_ids_logprobs[pos] if input_ids_logprobs else None - ) - - assert_position_logprobs_match( - self, - ref_lp_vec=hf_lp_vec, - sgl_token_logprob=input_token_logprobs[pos], - sgl_top_logprobs=sgl_top, - sgl_ids_logprobs=sgl_ids, - token_ids=random_token_ids, - top_k=TOP_LOGPROBS_NUM, - rtol=RTOL, - atol=ATOL, - tag=f"Input pos {pos}: {prompt} ({env_val})", - ) - - sgl_val = input_token_logprobs[pos][0] - if sgl_val is not None: - hf_val = hf_lp_vec[input_token_logprobs[pos][1]].item() - max_diff = max(max_diff, abs(hf_val - sgl_val)) - - print( - f"[Input logprobs {prompt} ({env_val})] " - f"max|diff| = {max_diff:.4f}" - ) + max_diff = 0.0 + for step in range(MULTI_TOKEN_COUNT): + hf_pos = prompt_len - 1 + step + assert_position_logprobs_match( + self, + ref_lp_vec=hf_all_lp[hf_pos], + sgl_token_logprob=meta["output_token_logprobs"][step], + sgl_top_logprobs=meta["output_top_logprobs"][step], + sgl_ids_logprobs=meta["output_token_ids_logprobs"][step], + token_ids=random_token_ids, + top_k=TOP_LOGPROBS_NUM, + rtol=RTOL, + atol=ATOL, + tag=f"Multi-token step {step}: {prompt} ({self.return_original_logprob})", + ) + sgl_val = meta["output_token_logprobs"][step][0] + hf_val = hf_all_lp[hf_pos][ + meta["output_token_logprobs"][step][1] + ].item() + max_diff = max(max_diff, abs(hf_val - sgl_val)) + + print( + f"[Multi-token {prompt} ({self.return_original_logprob})] " + f"max|diff| = {max_diff:.4f}" + ) - sgl_engine.shutdown() - - def test_logprob_start_len(self): - """Verify logprob_start_len correctly controls the starting index.""" - sampling_params = { - "temperature": 0.5, - "max_new_tokens": 4, - } - - os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = "True" - sgl_engine = sgl.Engine( - model_path=MODEL_ID, - skip_tokenizer_init=True, - trust_remote_code=True, - mem_fraction_static=0.60, - ) - for prompt in PROMPTS: - enc = self.tokenizer(prompt, return_tensors="pt") - input_ids = enc["input_ids"][0].tolist() - prompt_len = len(input_ids) - - hf_all_lp = hf_logprobs_for_sequence( - self.hf_model, input_ids, temperature=1.0 - ) - - for start_len in [0, 1, prompt_len // 2, prompt_len - 1]: - with self.subTest(prompt=prompt, logprob_start_len=start_len): - outputs = sgl_engine.generate( - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=True, - logprob_start_len=start_len, - top_logprobs_num=5, - ) +class TestOriginalLogprobEnabled(_OriginalLogprobBase): + """Tests with SGLANG_RETURN_ORIGINAL_LOGPROB=True.""" - if isinstance(outputs, list): - outputs = outputs[0] - meta = outputs["meta_info"] + return_original_logprob = "True" - expected_input_len = prompt_len - start_len - self.assertEqual( - len(meta["input_token_logprobs"]), - expected_input_len, - msg=f"input_token_logprobs length mismatch " - f"for start_len={start_len}", - ) - self.assertEqual( - len(meta["input_top_logprobs"]), - expected_input_len, - msg=f"input_top_logprobs length mismatch " - f"for start_len={start_len}", - ) - self.assertEqual( - len(meta["output_token_logprobs"]), - sampling_params["max_new_tokens"], - ) - self.assertEqual( - meta["prompt_tokens"], - start_len + len(meta["input_token_logprobs"]), - ) - # Spot-check: verify the first returned position against HF - if start_len > 0 and expected_input_len > 0: - assert_position_logprobs_match( - self, - ref_lp_vec=hf_all_lp[start_len - 1], - sgl_token_logprob=meta["input_token_logprobs"][0], - rtol=RTOL, - atol=ATOL, - tag=f"First position (start_len={start_len})", - ) - - print( - f"[logprob_start_len={start_len} {prompt}] " - f"input_logprobs_len=" - f"{len(meta['input_token_logprobs'])}, " - f"output_logprobs_len=" - f"{len(meta['output_token_logprobs'])}" - ) +class TestOriginalLogprobDisabled(_OriginalLogprobBase): + """Tests with SGLANG_RETURN_ORIGINAL_LOGPROB=False.""" - sgl_engine.shutdown() + return_original_logprob = "False" if __name__ == "__main__": From e58129fa4c0fadb12f11806b6ff48f5c9c864b2e Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Thu, 19 Mar 2026 01:47:43 +0000 Subject: [PATCH 09/12] fix --- .../managers/scheduler_output_processor_mixin.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index f4b2b9978048..7f7da472e5e4 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1030,6 +1030,8 @@ def stream_output_generation( and not req.input_logprob_sent # Decode server does not send input logprobs and self.disaggregation_mode != DisaggregationMode.DECODE + # Only send when input logprobs have been computed (after prefill) + and req.input_token_logprobs_val is not None ): input_token_logprobs_val.append(req.input_token_logprobs_val) input_token_logprobs_idx.append(req.input_token_logprobs_idx) @@ -1051,34 +1053,35 @@ def stream_output_generation( input_token_ids_logprobs_idx.append([]) if req.return_logprob: + logprob_end = req.finished_len output_token_logprobs_val.append( req.output_token_logprobs_val[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_token_logprobs_idx.append( req.output_token_logprobs_idx[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_top_logprobs_val.append( req.output_top_logprobs_val[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_top_logprobs_idx.append( req.output_top_logprobs_idx[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_token_ids_logprobs_val.append( req.output_token_ids_logprobs_val[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_token_ids_logprobs_idx.append( req.output_token_ids_logprobs_idx[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) req.send_output_token_logprobs_offset = len( From e4729de88fb1778d5fde9fade3df1ef79cdde3b6 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Thu, 19 Mar 2026 08:08:46 +0000 Subject: [PATCH 10/12] upd --- python/sglang/test/kits/logprob_kit.py | 119 +++++++++++-------------- 1 file changed, 51 insertions(+), 68 deletions(-) diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py index 018f6a7b9df4..1044bc86313d 100644 --- a/python/sglang/test/kits/logprob_kit.py +++ b/python/sglang/test/kits/logprob_kit.py @@ -14,18 +14,6 @@ - logprob_start_len: boundary correctness - return_text_in_logprobs: structural validation -HF ground-truth comparison utilities: - - from sglang.test.kits.logprob_kit import ( - hf_logprobs_for_sequence, - assert_position_logprobs_match, - ) - - hf_lp = hf_logprobs_for_sequence(hf_model, token_ids, temperature=1.0) - assert_position_logprobs_match( - test_case, ref_lp_vec=hf_lp[pos], sgl_token_logprob=..., - top_k=50, rtol=0.20, atol=0.0, tag="...", - ) """ import random @@ -53,56 +41,6 @@ # --------------------------------------------------------------------------- -def generate_with_logprobs( - url, - prompt_or_ids, - max_new_tokens, - top_logprobs_num, - token_ids_logprob=None, - logprob_start_len=-1, - return_text_in_logprobs=False, - sampling_params=None, -): - """Send a generate request with logprob options to an SGLang server. - - Args: - url: Server base URL. - prompt_or_ids: Text string or list of token IDs. - max_new_tokens: Number of tokens to generate. - top_logprobs_num: Number of top logprobs to return. - token_ids_logprob: Optional token IDs for the token_ids_logprob artifact. - logprob_start_len: Starting position for input logprobs (-1 to skip). - return_text_in_logprobs: Include token text in logprob tuples. - sampling_params: Optional dict of sampling parameters. Defaults to - ``{"temperature": 0, "ignore_eos": True}``. ``max_new_tokens`` - is always set from the explicit argument. - """ - if isinstance(prompt_or_ids, str): - payload = {"text": prompt_or_ids} - else: - payload = {"input_ids": prompt_or_ids} - - if sampling_params is None: - sampling_params = {"temperature": 0, "ignore_eos": True} - sampling_params = {**sampling_params, "max_new_tokens": max_new_tokens} - - payload.update( - { - "sampling_params": sampling_params, - "return_logprob": True, - "top_logprobs_num": top_logprobs_num, - "logprob_start_len": logprob_start_len, - "return_text_in_logprobs": return_text_in_logprobs, - } - ) - if token_ids_logprob is not None: - payload["token_ids_logprob"] = token_ids_logprob - - response = requests.post(url + "/generate", json=payload) - assert response.status_code == 200, f"Server error: {response.text}" - return response.json() - - def _compare_token_logprobs(test_case, target_lps, baseline_lps, places, tag): """Compare per-token logprob values, skipping None entries.""" test_case.assertEqual( @@ -194,10 +132,60 @@ def _compare_ids_logprobs(test_case, target_ids, baseline_ids, places, tag): # --------------------------------------------------------------------------- -# Public API – HF ground-truth comparison utilities +# Public API # --------------------------------------------------------------------------- +def generate_with_logprobs( + url, + prompt_or_ids, + max_new_tokens, + top_logprobs_num, + token_ids_logprob=None, + logprob_start_len=-1, + return_text_in_logprobs=False, + sampling_params=None, +): + """Send a generate request with logprob options to an SGLang server. + + Args: + url: Server base URL. + prompt_or_ids: Text string or list of token IDs. + max_new_tokens: Number of tokens to generate. + top_logprobs_num: Number of top logprobs to return. + token_ids_logprob: Optional token IDs for the token_ids_logprob artifact. + logprob_start_len: Starting position for input logprobs (-1 to skip). + return_text_in_logprobs: Include token text in logprob tuples. + sampling_params: Optional dict of sampling parameters. Defaults to + ``{"temperature": 0, "ignore_eos": True}``. ``max_new_tokens`` + is always set from the explicit argument. + """ + if isinstance(prompt_or_ids, str): + payload = {"text": prompt_or_ids} + else: + payload = {"input_ids": prompt_or_ids} + + if sampling_params is None: + sampling_params = {"temperature": 0, "ignore_eos": True} + sampling_params = {**sampling_params, "max_new_tokens": max_new_tokens} + + payload.update( + { + "sampling_params": sampling_params, + "return_logprob": True, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": logprob_start_len, + "return_text_in_logprobs": return_text_in_logprobs, + } + ) + if token_ids_logprob is not None: + payload["token_ids_logprob"] = token_ids_logprob + + response = requests.post(url + "/generate", json=payload) + assert response.status_code == 200, f"Server error: {response.text}" + return response.json() + + def hf_logprobs_for_sequence(hf_model, token_ids, temperature=1.0): """Run a HuggingFace forward pass and return per-position log-prob vectors. @@ -300,11 +288,6 @@ def assert_position_logprobs_match( ) -# --------------------------------------------------------------------------- -# Public API – cross-mode comparison functions -# --------------------------------------------------------------------------- - - def run_logprob_cross_mode_check( test_case, target_url, From ee91344115fad1d29fb87ae391b99472417bcf69 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Thu, 19 Mar 2026 08:43:31 +0000 Subject: [PATCH 11/12] upd --- .../scheduler_output_processor_mixin.py | 2 +- python/sglang/test/kits/logprob_kit.py | 123 ++++++++++++++---- python/sglang/test/test_utils.py | 75 ----------- test/registered/core/test_srt_endpoint.py | 2 +- 4 files changed, 102 insertions(+), 100 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 7f7da472e5e4..2c00f187a499 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1084,7 +1084,7 @@ def stream_output_generation( send_output_token_logprobs_offset:logprob_end ] ) - req.send_output_token_logprobs_offset = len( + req.send_output_token_logprobs_offset = logprob_end or len( req.output_token_logprobs_val ) else: diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py index 1044bc86313d..12f8bc66c660 100644 --- a/python/sglang/test/kits/logprob_kit.py +++ b/python/sglang/test/kits/logprob_kit.py @@ -17,14 +17,13 @@ """ import random +import unittest from concurrent.futures import ThreadPoolExecutor from functools import partial import numpy as np import requests -from sglang.test.test_utils import run_logprob_check - CROSS_MODE_PROMPTS = [ "The capital of France is", "Explain quantum computing in simple terms:", @@ -182,7 +181,8 @@ def generate_with_logprobs( payload["token_ids_logprob"] = token_ids_logprob response = requests.post(url + "/generate", json=payload) - assert response.status_code == 200, f"Server error: {response.text}" + if response.status_code != 200: + raise RuntimeError(f"Server returned {response.status_code}: {response.text}") return response.json() @@ -290,18 +290,17 @@ def assert_position_logprobs_match( def run_logprob_cross_mode_check( test_case, - target_url, - baseline_url=None, prompts=None, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, top_logprobs_num=DEFAULT_TOP_LOGPROBS_NUM, token_ids_logprob=None, return_text_in_logprobs=False, decimal_places=DEFAULT_DECIMAL_PLACES, + baseline_url=None, ): """Run cross-mode logprob comparison across all artifacts. - 1. Generate tokens from the target server (with logprob_start_len=0). + 1. Generate tokens from ``test_case.base_url`` (with logprob_start_len=0). 2. Score the full sequence via prefill on the baseline server. 3. Compare output_token_logprobs, input_token_logprobs, output_top_logprobs, input_top_logprobs, output_token_ids_logprobs, input_token_ids_logprobs. @@ -312,16 +311,16 @@ def run_logprob_cross_mode_check( non-speculative baseline for its own decode logprobs. Args: - test_case: ``unittest.TestCase`` instance for assertions. - target_url: URL of the target server (e.g. speculative decoding). - baseline_url: URL of the baseline server. Defaults to *target_url*. + test_case: ``unittest.TestCase`` instance (must have ``base_url``). prompts: List of prompt strings. max_new_tokens: Tokens to generate per prompt. top_logprobs_num: Top-k count for top_logprobs. token_ids_logprob: Token IDs for the token_ids_logprob artifact. return_text_in_logprobs: Whether to include token text. decimal_places: ``assertAlmostEqual`` precision (``places``). + baseline_url: URL of the baseline server. Defaults to ``test_case.base_url``. """ + target_url = test_case.base_url if baseline_url is None: baseline_url = target_url if prompts is None: @@ -441,23 +440,22 @@ def run_logprob_cross_mode_check( def run_logprob_start_len_check( test_case, - target_url, prompts=None, max_new_tokens=8, start_lens=None, ): - """Verify logprob_start_len boundary correctness on the target server. + """Verify logprob_start_len boundary correctness. - For each prompt and start_len, verifies that: + Uses ``test_case.base_url``. For each prompt and start_len, verifies that: - ``len(input_token_logprobs) == prompt_tokens - logprob_start_len`` - ``len(input_top_logprobs)`` matches - ``len(output_token_logprobs) == max_new_tokens`` """ if prompts is None: prompts = list(CROSS_MODE_PROMPTS) + target_url = test_case.base_url for prompt in prompts: - # Probe to get prompt_tokens probe_res = generate_with_logprobs( target_url, prompt, @@ -513,9 +511,91 @@ def run_logprob_start_len_check( ) +def run_logprob_check(test_case: unittest.TestCase, arg): + """Verify logprob shape correctness for a single parameter combination. + + Args: + test_case: ``unittest.TestCase`` instance (must have ``self.base_url``). + arg: Tuple of (input_len, output_len, temperature, logprob_start_len, + return_logprob, top_logprobs_num). + """ + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) = arg + input_ids = list(range(input_len)) + + response = requests.post( + test_case.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": output_len, + "ignore_eos": True, + }, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + }, + ) + res = response.json() + + test_case.assertEqual(res["meta_info"]["prompt_tokens"], input_len) + test_case.assertEqual(res["meta_info"]["completion_tokens"], output_len) + + if return_logprob: + test_case.assertEqual( + len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len, + res["meta_info"]["prompt_tokens"], + ) + test_case.assertEqual( + len(res["meta_info"]["output_token_logprobs"]), output_len + ) + + if top_logprobs_num: + test_case.assertEqual( + len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len, + res["meta_info"]["prompt_tokens"], + ) + test_case.assertEqual( + len(res["meta_info"]["output_top_logprobs"]), output_len + ) + + for i in range(output_len): + test_case.assertEqual( + len(res["meta_info"]["output_top_logprobs"][i]), + top_logprobs_num, + ) + + if temperature == 0: + rank = 0 + while rank < len(res["meta_info"]["output_top_logprobs"][i]): + try: + test_case.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][rank], + ) + break + except AssertionError: + # Tie-breaking: allow next rank when values match. + if ( + res["meta_info"]["output_top_logprobs"][i][rank][0] + == res["meta_info"]["output_top_logprobs"][i][rank + 1][ + 0 + ] + ): + rank += 1 + else: + raise + + def run_logprob_mixed_check( test_case, - base_url, input_lens=None, output_lens=None, logprob_start_lens=None, @@ -523,13 +603,13 @@ def run_logprob_mixed_check( ): """Stress-test logprob shape correctness with many parameter combinations. - Sends concurrent requests with various (input_len, output_len, - logprob_start_len, return_logprob, top_logprobs_num) combos and verifies - that the returned array lengths are all correct. + Uses ``test_case.base_url``. Sends concurrent requests with various + (input_len, output_len, logprob_start_len, return_logprob, + top_logprobs_num) combos and verifies that the returned array lengths + are all correct. Args: - test_case: ``unittest.TestCase`` instance. - base_url: Server URL. + test_case: ``unittest.TestCase`` instance (must have ``base_url``). input_lens: List of input lengths to test. output_lens: List of output lengths to test. logprob_start_lens: List of logprob_start_len values to test. @@ -599,7 +679,6 @@ def test_cross_mode_logprobs(self): print(f"Testing cross-mode logprobs for {self.base_url}") run_logprob_cross_mode_check( self, - self.base_url, prompts=self.logprob_prompts, max_new_tokens=self.logprob_max_new_tokens, top_logprobs_num=self.logprob_top_k, @@ -611,7 +690,6 @@ def test_cross_mode_logprob_start_len(self): """Verify logprob_start_len boundary behaviour.""" run_logprob_start_len_check( self, - self.base_url, prompts=self.logprob_prompts, ) @@ -619,7 +697,6 @@ def test_cross_mode_return_text_in_logprobs(self): """Verify return_text_in_logprobs structural correctness.""" run_logprob_cross_mode_check( self, - self.base_url, prompts=(self.logprob_prompts or CROSS_MODE_PROMPTS)[:1], max_new_tokens=8, return_text_in_logprobs=True, @@ -628,4 +705,4 @@ def test_cross_mode_return_text_in_logprobs(self): def test_cross_mode_logprob_mixed(self): """Stress-test logprob shape correctness with many parameter combos.""" - run_logprob_mixed_check(self, self.base_url) + run_logprob_mixed_check(self) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 064626d67303..7dee89808ab8 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1843,81 +1843,6 @@ def write_github_step_summary(content): f.write(content) -def run_logprob_check(self: unittest.TestCase, arg: Tuple): - ( - input_len, - output_len, - temperature, - logprob_start_len, - return_logprob, - top_logprobs_num, - ) = arg - input_ids = list(range(input_len)) - - response = requests.post( - self.base_url + "/generate", - json={ - "input_ids": input_ids, - "sampling_params": { - "temperature": temperature, - "max_new_tokens": output_len, - "ignore_eos": True, - }, - "return_logprob": return_logprob, - "logprob_start_len": logprob_start_len, - "top_logprobs_num": top_logprobs_num, - }, - ) - response_json = response.json() - - res = response_json - self.assertEqual(res["meta_info"]["prompt_tokens"], input_len) - self.assertEqual(res["meta_info"]["completion_tokens"], output_len) - - # Test the number of tokens are correct - if return_logprob: - self.assertEqual( - len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len, - res["meta_info"]["prompt_tokens"], - ) - self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len) - - if top_logprobs_num: - self.assertEqual( - len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len, - res["meta_info"]["prompt_tokens"], - ) - self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len) - - for i in range(output_len): - self.assertEqual( - len(res["meta_info"]["output_top_logprobs"][i]), - top_logprobs_num, - ) - - # Test the top-1 tokens are the same as output tokens if temperature == 0 - if temperature == 0: - rank = 0 - while rank < len(res["meta_info"]["output_top_logprobs"][i]): - try: - self.assertListEqual( - res["meta_info"]["output_token_logprobs"][i], - res["meta_info"]["output_top_logprobs"][i][rank], - ) - break - except AssertionError: - # There's a tie. Allow the second item in this case. - if ( - res["meta_info"]["output_top_logprobs"][i][rank][0] - == res["meta_info"]["output_top_logprobs"][i][rank + 1][ - 0 - ] - ): - rank += 1 - else: - raise - - def send_generate_requests(base_url: str, num_requests: int) -> List[str]: """Sends generate request serially and returns status codes. Max concurrency is 1.""" diff --git a/test/registered/core/test_srt_endpoint.py b/test/registered/core/test_srt_endpoint.py index 92aeff84cf00..f3c7f777ceac 100644 --- a/test/registered/core/test_srt_endpoint.py +++ b/test/registered/core/test_srt_endpoint.py @@ -18,13 +18,13 @@ from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.kits.logprob_kit import run_logprob_check from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_server, - run_logprob_check, ) register_cuda_ci(est_time=127, suite="stage-b-test-small-1-gpu") From 8a3e50567600e041690a032a9fbbfc2d8523dc52 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Thu, 19 Mar 2026 09:00:14 +0000 Subject: [PATCH 12/12] upd --- .../scheduler_output_processor_mixin.py | 6 ++++-- python/sglang/test/kits/logprob_kit.py | 20 +++++-------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 2c00f187a499..4abd73ff158b 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1084,8 +1084,10 @@ def stream_output_generation( send_output_token_logprobs_offset:logprob_end ] ) - req.send_output_token_logprobs_offset = logprob_end or len( - req.output_token_logprobs_val + req.send_output_token_logprobs_offset = ( + logprob_end + if logprob_end is not None + else len(req.output_token_logprobs_val) ) else: output_token_logprobs_val.append([]) diff --git a/python/sglang/test/kits/logprob_kit.py b/python/sglang/test/kits/logprob_kit.py index 12f8bc66c660..a72b2e05ba09 100644 --- a/python/sglang/test/kits/logprob_kit.py +++ b/python/sglang/test/kits/logprob_kit.py @@ -296,20 +296,15 @@ def run_logprob_cross_mode_check( token_ids_logprob=None, return_text_in_logprobs=False, decimal_places=DEFAULT_DECIMAL_PLACES, - baseline_url=None, ): """Run cross-mode logprob comparison across all artifacts. 1. Generate tokens from ``test_case.base_url`` (with logprob_start_len=0). - 2. Score the full sequence via prefill on the baseline server. + 2. Score the same full sequence via prefill on the same server + (``max_new_tokens=0``), so speculative decoding is not involved. 3. Compare output_token_logprobs, input_token_logprobs, output_top_logprobs, input_top_logprobs, output_token_ids_logprobs, input_token_ids_logprobs. - When *baseline_url* is ``None``, the target server itself is used for the - scoring step. Because scoring uses ``max_new_tokens=0`` (pure prefill), - speculative decoding is not involved, making the target server a valid - non-speculative baseline for its own decode logprobs. - Args: test_case: ``unittest.TestCase`` instance (must have ``base_url``). prompts: List of prompt strings. @@ -318,11 +313,8 @@ def run_logprob_cross_mode_check( token_ids_logprob: Token IDs for the token_ids_logprob artifact. return_text_in_logprobs: Whether to include token text. decimal_places: ``assertAlmostEqual`` precision (``places``). - baseline_url: URL of the baseline server. Defaults to ``test_case.base_url``. """ - target_url = test_case.base_url - if baseline_url is None: - baseline_url = target_url + base_url = test_case.base_url if prompts is None: prompts = list(CROSS_MODE_PROMPTS) if token_ids_logprob is None: @@ -332,9 +324,8 @@ def run_logprob_cross_mode_check( tag_prefix = f"round {round_idx}" print(f"\n--- Cross-mode check {tag_prefix}: {prompt!r} ---") - # Step 1: generate from target with logprob_start_len=0 gen_res = generate_with_logprobs( - target_url, + base_url, prompt, max_new_tokens, top_logprobs_num, @@ -349,9 +340,8 @@ def run_logprob_cross_mode_check( output_token_ids = [t[1] for t in meta["output_token_logprobs"]] full_sequence = input_token_ids + output_token_ids - # Step 2: score the full sequence via prefill on baseline score_res = generate_with_logprobs( - baseline_url, + base_url, full_sequence, max_new_tokens=0, top_logprobs_num=top_logprobs_num,