diff --git a/python/sglang/test/kits/eval_accuracy_kit.py b/python/sglang/test/kits/eval_accuracy_kit.py index 25bf58151c9f..9757dc01523e 100644 --- a/python/sglang/test/kits/eval_accuracy_kit.py +++ b/python/sglang/test/kits/eval_accuracy_kit.py @@ -9,12 +9,16 @@ _THRESHOLD_NOT_SET = float("nan") -def _check_accept_length(test_case, base_url, threshold): - """Check speculative decoding accept length from server info.""" - server_info = requests.get(base_url + "/get_server_info").json() - avg_spec_accept_length = server_info["internal_states"][0]["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - test_case.assertGreater(avg_spec_accept_length, threshold) +def _check_accept_length(test_case, base_url, threshold=None): + """Print accept length; optionally assert it exceeds threshold.""" + try: + server_info = requests.get(base_url + "/server_info").json() + val = server_info["internal_states"][0]["avg_spec_accept_length"] + except (KeyError, IndexError, requests.RequestException): + return + print(f"avg_spec_accept_length={val:.4f}") + if threshold is not None: + test_case.assertGreater(val, threshold) class GSM8KMixin: @@ -57,8 +61,7 @@ def test_gsm8k(self): self.assertGreaterEqual(metrics["score"], self.gsm8k_accuracy_thres) - if self.gsm8k_accept_length_thres is not None: - _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) + _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) class MMLUMixin: @@ -95,8 +98,7 @@ def test_mmlu(self): self.assertGreaterEqual(metrics["score"], self.mmlu_score_threshold) - if self.mmlu_accept_length_thres is not None: - _check_accept_length(self, self.base_url, self.mmlu_accept_length_thres) + _check_accept_length(self, self.base_url, self.mmlu_accept_length_thres) class HumanEvalMixin: @@ -136,6 +138,8 @@ def test_human_eval(self): self.assertGreaterEqual(metrics["score"], threshold) + _check_accept_length(self, self.base_url) + class MGSMEnMixin: """Mixin for MGSM English evaluation. @@ -169,3 +173,5 @@ def test_mgsm_en(self): write_github_step_summary(f"### test_mgsm_en\n{metrics['score']=:.4f}\n") self.assertGreaterEqual(metrics["score"], self.mgsm_en_score_threshold) + + _check_accept_length(self, self.base_url) diff --git a/test/registered/spec/test_ngram_speculative_decoding.py b/test/registered/spec/test_ngram_speculative_decoding.py index f80b1e646dea..d8e0c467b6b4 100644 --- a/test/registered/spec/test_ngram_speculative_decoding.py +++ b/test/registered/spec/test_ngram_speculative_decoding.py @@ -111,7 +111,7 @@ def generate_batch(): return outputs def get_accept_length(): - info = requests.get(self.base_url + "/get_server_info").json() + info = requests.get(self.base_url + "/server_info").json() return info["internal_states"][0]["avg_spec_accept_length"] # Phase 1: baseline — no SAM corpus loaded, only trie