diff --git a/python/sglang/test/scripted_runtime_chunked_helpers.py b/python/sglang/test/scripted_runtime_chunked_helpers.py new file mode 100644 index 000000000000..a5b0fac15e06 --- /dev/null +++ b/python/sglang/test/scripted_runtime_chunked_helpers.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Any, Dict, List + +DEFAULT_CHUNK_SIZE: int = 256 + +DEFAULT_MAX_STEPS: int = 400 + +VERY_LONG_PROMPT_LEN: int = 8 * DEFAULT_CHUNK_SIZE + +SMALL_MODEL: str = "Qwen/Qwen3-0.6B" + + +def base_engine_kwargs( + *, + model_path: str = SMALL_MODEL, + chunked_prefill_size: int = DEFAULT_CHUNK_SIZE, + **overrides: Any, +) -> Dict[str, Any]: + kwargs: Dict[str, Any] = dict( + model_path=model_path, + chunked_prefill_size=chunked_prefill_size, + ) + kwargs.update(overrides) + return kwargs + + +def run_until(handle, predicate, *, max_steps: int = DEFAULT_MAX_STEPS): + for _ in range(max_steps): + if predicate(handle): + return + yield + raise AssertionError( + f"run_until: predicate never satisfied after {max_steps} steps " + f"(handle rid={handle.rid!r}, finished={handle.finished})" + ) + + +def run_until_finished(handle, *, max_steps: int = DEFAULT_MAX_STEPS): + yield from run_until(handle, lambda h: h.finished, max_steps=max_steps) + + +def run_until_all_finished(handles: List[Any], *, max_steps: int = DEFAULT_MAX_STEPS): + for _ in range(max_steps): + if all(h.finished for h in handles): + return + yield + raise AssertionError( + f"run_until_all_finished: not all reqs finished after {max_steps} " + f"steps (finished={[h.finished for h in handles]})" + ) + + +def warmup_radix(t, prompt_tokens: List[int], *, max_steps: int = DEFAULT_MAX_STEPS): + assert prompt_tokens, "warmup_radix needs a non-empty prompt" + token = prompt_tokens[0] + assert all( + x == token for x in prompt_tokens + ), "warmup_radix supports only uniform prompts" + handle = t.start_req( + prompt_len=len(prompt_tokens), max_new_tokens=1, prompt_token=token + ) + yield from 
run_until_finished(handle, max_steps=max_steps) + + +BALLAST_MAX_NEW_TOKENS: int = 30000 + + +def exhaust_row_pool(t, *, leave_rows: int, max_steps: int = DEFAULT_MAX_STEPS): + target: int = t.scheduler.req_to_token_pool.available_size() - leave_rows + if target <= 0: + return + + for _ in range(target): + t.start_req( + prompt_len=1, max_new_tokens=BALLAST_MAX_NEW_TOKENS, ignore_eos=True + ) + + for _ in range(max_steps): + if t.scheduler.req_to_token_pool.available_size() <= leave_rows: + return + yield + raise AssertionError( + f"exhaust_row_pool: ballast reqs never filled the row pool down to " + f"leave_rows={leave_rows} after {max_steps} steps " + f"(available_size={t.scheduler.req_to_token_pool.available_size()})" + ) + + +LIFECYCLE_STAGES = ( + "first_chunk", + "last_chunk", + "first_decode", + "mid_decode", + "last_decode", +) + + +def advance_to_nth_chunk(r, target_chunk: int, *, max_steps: int = DEFAULT_MAX_STEPS): + # Drive until the hook has recorded `target_chunk` chunked-prefill batches. + # chunks_done is accumulated from on_run_batch (every forward batch), so it + # never misses a chunk the way sampling the instantaneous is_chunking flag + # once per yield can: on the step the req leaves chunked_req, is_chunking is + # already False, so `seen` undercounted and the req could race to completion + # on slower CI before the loop caught up. 
+ for _ in range(max_steps): + assert not r.finished, f"req finished before reaching chunk {target_chunk}" + if r.chunks_done >= target_chunk: + return + yield + raise AssertionError( + f"never reached chunk {target_chunk} (chunks_done={r.chunks_done})" + ) + + +def advance_to_decode_step( + r, target_output_len: int, *, max_steps: int = DEFAULT_MAX_STEPS +): + for _ in range(max_steps): + assert ( + not r.finished + ), f"req finished before reaching decode step {target_output_len}" + req = r.req + if req is not None and len(req.output_ids) >= target_output_len: + return + yield + raise AssertionError(f"never reached decode step {target_output_len}") + + +def advance_to_lifecycle_stage( + r, + stage: str, + *, + num_middle_chunks: int, + max_new_tokens: int, + max_steps: int = DEFAULT_MAX_STEPS, +): + if stage == "first_chunk": + yield from advance_to_nth_chunk(r, 1, max_steps=max_steps) + elif stage == "last_chunk": + yield from advance_to_nth_chunk(r, num_middle_chunks, max_steps=max_steps) + elif stage == "first_decode": + yield from advance_to_decode_step(r, 1, max_steps=max_steps) + elif stage == "mid_decode": + yield from advance_to_decode_step( + r, max(1, max_new_tokens // 2), max_steps=max_steps + ) + elif stage == "last_decode": + yield from advance_to_decode_step(r, max_new_tokens - 1, max_steps=max_steps) + else: + raise AssertionError(f"unknown lifecycle stage {stage!r}") diff --git a/test/registered/chunked_prefill/test_scripted_core_1gpu.py b/test/registered/chunked_prefill/test_scripted_core_1gpu.py new file mode 100644 index 000000000000..c6f36d29a827 --- /dev/null +++ b/test/registered/chunked_prefill/test_scripted_core_1gpu.py @@ -0,0 +1,188 @@ +import unittest + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.scripted_runtime.context import ScriptedContext +from sglang.test.scripted_runtime.test_case import ScriptedTestCase +from sglang.test.scripted_runtime_chunked_helpers import ( + LIFECYCLE_STAGES, + 
advance_to_lifecycle_stage, + base_engine_kwargs, + run_until_finished, +) + +register_cuda_ci(est_time=300, stage="extra-a", runner_config="1-gpu-small") + + +_CHUNK_SIZE = 64 +_PROMPT_LEN = 4 * _CHUNK_SIZE - 3 + +_NUM_MIDDLE_CHUNKS = (_PROMPT_LEN - 1) // _CHUNK_SIZE +_LIFECYCLE_MAX_NEW_TOKENS = 4 + + +def _advance_to_stage(r, stage: str): + yield from advance_to_lifecycle_stage( + r, + stage, + num_middle_chunks=_NUM_MIDDLE_CHUNKS, + max_new_tokens=_LIFECYCLE_MAX_NEW_TOKENS, + ) + + +class TestScriptedCore(ScriptedTestCase): + ENGINE_KWARGS = base_engine_kwargs(chunked_prefill_size=_CHUNK_SIZE) + + def test_chunked_prefill_smoke(self): + self.server.execute_script(self._script_chunked_prefill_smoke) + + @staticmethod + def _script_chunked_prefill_smoke(t: ScriptedContext): + r = t.start_req(prompt_len=_PROMPT_LEN, max_new_tokens=3) + yield from run_until_finished(r) + assert r.finished, "req did not finish" + + def test_chunked_prefill_smoke_at_chunk_boundary_offsets(self): + for offset in (-2, -1, 1, 2): + prompt_len = 2 * _CHUNK_SIZE + offset + with self.subTest(offset=offset, prompt_len=prompt_len): + self.server.execute_script( + self._script_chunked_prefill_smoke_at_offset, + args=(prompt_len,), + ) + + @staticmethod + def _script_chunked_prefill_smoke_at_offset(t: ScriptedContext, prompt_len: int): + r = t.start_req(prompt_len=prompt_len, max_new_tokens=3) + yield from run_until_finished(r) + assert r.finished, f"req with prompt_len={prompt_len} did not finish" + + def test_pause_retract_at_lifecycle_points_then_resume(self): + for stage in LIFECYCLE_STAGES: + with self.subTest(stage=stage): + self.server.execute_script( + self._script_pause_retract_at_stage, + args=(stage,), + ) + + @staticmethod + def _script_pause_retract_at_stage(t: ScriptedContext, stage: str): + r = t.start_req( + prompt_len=_PROMPT_LEN, max_new_tokens=_LIFECYCLE_MAX_NEW_TOKENS + ) + yield from _advance_to_stage(r, stage) + + assert r.req is not None, f"stage={stage}: req vanished 
before pause" + + t.pause_generation(mode="retract") + yield + + # At the last_decode stage the final decode can complete during the + # retract; a finished req is removed from the scheduler, so its + # output_ids are no longer observable through the harness. That case's + # only observable consequence — clean completion — is covered by the + # run_until_finished tail below. When the req is not finished, + # pause(retract) must park it back in the waiting_queue and the paused + # engine must not advance it. + if not r.finished: + req = r.req + assert req is not None and req in t.scheduler.waiting_queue, ( + f"stage={stage}: pause(retract) should park the req back in " + f"waiting_queue; found={req!r}" + ) + output_tokens_after_pause = len(req.output_ids) + for _ in range(3): + yield + req = r.req + assert ( + req is not None and len(req.output_ids) == output_tokens_after_pause + ), ( + f"stage={stage}: paused engine advanced the req " + f"({len(req.output_ids) if req is not None else None} output " + f"tokens, expected {output_tokens_after_pause})" + ) + + t.continue_generation() + yield from run_until_finished(r) + assert r.finished, f"stage={stage}: req did not finish after pause/continue" + + def test_abort_all_at_lifecycle_points(self): + for stage in LIFECYCLE_STAGES: + with self.subTest(stage=stage): + self.server.execute_script( + self._script_abort_all_at_stage, args=(stage,) + ) + + @staticmethod + def _script_abort_all_at_stage(t: ScriptedContext, stage: str): + r = t.start_req( + prompt_len=_PROMPT_LEN, max_new_tokens=_LIFECYCLE_MAX_NEW_TOKENS + ) + yield from _advance_to_stage(r, stage) + + t.abort_all() + for _ in range(8): + yield + if r.finished: + break + + assert r.finished, f"stage={stage}: req did not finish after abort_all" + + def test_chunked_req_single_decode_finishes(self): + self.server.execute_script(self._script_chunked_req_single_decode_finishes) + + @staticmethod + def _script_chunked_req_single_decode_finishes(t: ScriptedContext): + r 
= t.start_req(prompt_len=_PROMPT_LEN, max_new_tokens=1) + yield from run_until_finished(r) + assert r.finished, "single-decode chunked req did not finish" + + def test_chunked_prefill_radix_hit_count(self): + self.server.execute_script(self._script_chunked_prefill_radix_hit_count) + + @staticmethod + def _script_chunked_prefill_radix_hit_count(t: ScriptedContext): + r = t.start_req(prompt_len=_PROMPT_LEN, max_new_tokens=2) + yield from run_until_finished(r) + assert r.finished + _assert_prefill_twice_decode_once(t, prompt_len=_PROMPT_LEN) + + def test_nonchunked_prefill_radix_hit_count(self): + self.server.execute_script(self._script_nonchunked_prefill_radix_hit_count) + + @staticmethod + def _script_nonchunked_prefill_radix_hit_count(t: ScriptedContext): + prompt_len = _CHUNK_SIZE - 20 + r = t.start_req(prompt_len=prompt_len, max_new_tokens=2) + yield from run_until_finished(r) + assert r.finished + _assert_prefill_twice_decode_once(t, prompt_len=prompt_len) + + +def _assert_prefill_twice_decode_once(t: ScriptedContext, *, prompt_len: int) -> None: + root = t.scheduler.tree_cache.root_node + prefill_hits: list[int] = [] + decode_hits: list[int] = [] + stack = [(child, len(child.key)) for child in root.children.values()] + while stack: + node, end_index = stack.pop() + bucket = prefill_hits if end_index <= prompt_len else decode_hits + bucket.append(node.hit_count) + for child in node.children.values(): + stack.append((child, end_index + len(child.key))) + + assert prefill_hits and decode_hits, ( + f"expected both prefill and decode radix nodes; " + f"prefill={prefill_hits}, decode={decode_hits}, prompt_len={prompt_len}" + ) + assert all(h == 2 for h in prefill_hits), ( + f"each prefill node must be hit exactly twice; " + f"prefill={prefill_hits}, decode={decode_hits}" + ) + assert all(h == 1 for h in decode_hits), ( + f"each decode node must be hit exactly once; " + f"prefill={prefill_hits}, decode={decode_hits}" + ) + + +if __name__ == "__main__": + 
unittest.main() diff --git a/test/registered/chunked_prefill/test_scripted_core_4gpu.py b/test/registered/chunked_prefill/test_scripted_core_4gpu.py new file mode 100644 index 000000000000..2ff60d749563 --- /dev/null +++ b/test/registered/chunked_prefill/test_scripted_core_4gpu.py @@ -0,0 +1,60 @@ +import unittest + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.scripted_runtime.context import ScriptedContext +from sglang.test.scripted_runtime.test_case import ScriptedTestCase +from sglang.test.scripted_runtime_chunked_helpers import ( + DEFAULT_MAX_STEPS, + SMALL_MODEL, + base_engine_kwargs, + run_until_all_finished, +) + +register_cuda_ci(est_time=900, stage="extra-b", runner_config="4-gpu-h100") + + +_CHUNK_SIZE = 64 + + +class TestScriptedPpChunkSweep(ScriptedTestCase): + ENGINE_KWARGS = base_engine_kwargs( + model_path=SMALL_MODEL, + tp_size=1, + dp_size=1, + pp_size=4, + pp_async_batch_depth=2, + chunked_prefill_size=_CHUNK_SIZE, + ) + + _PP_LOOP_SIZE = ENGINE_KWARGS["pp_size"] + ENGINE_KWARGS["pp_async_batch_depth"] + assert _PP_LOOP_SIZE == 6 + + _NUM_CHUNKS_VALUES = (2, 4, 6, 8) + _NUM_CONC_REQS_VALUES = (1, 2, 4) + + def test_pp_chunk_sweep(self): + for num_chunks in self._NUM_CHUNKS_VALUES: + for num_conc_reqs in self._NUM_CONC_REQS_VALUES: + with self.subTest(num_chunks=num_chunks, num_conc_reqs=num_conc_reqs): + self.server.execute_script( + self._script_pp_one_combo, + args=(num_chunks, num_conc_reqs), + ) + + @staticmethod + def _script_pp_one_combo(t: ScriptedContext, num_chunks: int, num_conc_reqs: int): + prompt_len = num_chunks * _CHUNK_SIZE - 3 + reqs = [ + t.start_req(prompt_len=prompt_len, max_new_tokens=2) + for _ in range(num_conc_reqs) + ] + yield from run_until_all_finished(reqs, max_steps=DEFAULT_MAX_STEPS) + for r in reqs: + assert r.finished, ( + f"combo num_chunks={num_chunks}, num_conc_reqs={num_conc_reqs}: " + f"req {r.rid!r} did not finish" + ) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/test/registered/chunked_prefill/test_scripted_swa_1gpu.py b/test/registered/chunked_prefill/test_scripted_swa_1gpu.py new file mode 100644 index 000000000000..3a7468260bb7 --- /dev/null +++ b/test/registered/chunked_prefill/test_scripted_swa_1gpu.py @@ -0,0 +1,107 @@ +import unittest + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.scripted_runtime.context import ScriptedContext +from sglang.test.scripted_runtime.test_case import ScriptedTestCase +from sglang.test.scripted_runtime_chunked_helpers import base_engine_kwargs + +register_cuda_ci(est_time=400, stage="extra-a", runner_config="1-gpu-large") + + +_SWA_MODEL = "openai/gpt-oss-20b" + +_MAX_TOTAL_TOKENS = 4096 +_SWA_FULL_TOKENS_RATIO = 0.1 +_CHUNK_SIZE = 64 + +_N_DECODERS = 6 +_DECODER_PROMPT = 64 +_DECODER_MAX_NEW = 512 +_DECODER_WARMUP_RUNNING = 3 + +_CHUNKED_PROMPT = 384 +_CHUNKED_MAX_NEW = 2 +_N_CANDIDATES = 24 +_STEPS_PER_CANDIDATE = 120 +_DECODER_WARMUP_STEPS = 60 +_DRAIN_STEPS = 400 + + +class TestScriptedSwaChunkedReqEarlyReturn(ScriptedTestCase): + ENGINE_KWARGS = base_engine_kwargs( + model_path=_SWA_MODEL, + chunked_prefill_size=_CHUNK_SIZE, + max_total_tokens=_MAX_TOTAL_TOKENS, + swa_full_tokens_ratio=_SWA_FULL_TOKENS_RATIO, + page_size=1, + mem_fraction_static=0.70, + ) + + def test_swa_chunked_req_early_return_no_double_free(self): + self.server.execute_script( + self._script_swa_chunked_req_early_return_no_double_free + ) + + @staticmethod + def _script_swa_chunked_req_early_return_no_double_free(t: ScriptedContext): + s = t.scheduler + + for i in range(_N_DECODERS): + t.start_req( + prompt_len=_DECODER_PROMPT, + max_new_tokens=_DECODER_MAX_NEW, + ignore_eos=True, + prompt_token=10 + i, + ) + for _ in range(_DECODER_WARMUP_STEPS): + if len(s.running_batch.reqs) >= _DECODER_WARMUP_RUNNING: + break + yield + + candidates = [] + parked = False + for _ in range(_N_CANDIDATES): + candidates.append( + t.start_req( + prompt_len=_CHUNKED_PROMPT, + 
max_new_tokens=_CHUNKED_MAX_NEW, + prompt_token=2, + ) + ) + for _ in range(_STEPS_PER_CANDIDATE): + if any(t.chunked_parks(c.rid) > 0 for c in candidates): + parked = True + break + if candidates[-1].finished: + break + yield + if parked: + break + + parked = parked or any(t.chunked_parks(c.rid) > 0 for c in candidates) + assert parked, ( + "no chunked candidate was ever parked by add_chunked_req's hybrid-SWA " + "early-return; the test never exercised the stash gate" + ) + + t.abort_all() + for _ in range(_DRAIN_STEPS): + if ( + s.chunked_req is None + and len(s.waiting_queue) == 0 + and s.running_batch.is_empty() + ): + break + yield + for _ in range(20): + yield + + locked = {nid: lr for nid, lr in t.get_all_node_lock_refs().items() if lr != 0} + assert not locked, ( + f"radix nodes left locked after drain {locked} -- stash gate let an " + "un-scheduled chunked req commit partial KV" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/scripted_runtime/test_scripted_runtime_core.py b/test/registered/scripted_runtime/test_scripted_runtime_core.py new file mode 100644 index 000000000000..5bafe521d872 --- /dev/null +++ b/test/registered/scripted_runtime/test_scripted_runtime_core.py @@ -0,0 +1,804 @@ +import unittest + +from sglang.srt.managers.schedule_batch import FINISH_ABORT +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.scripted_runtime.context import ScriptedContext +from sglang.test.scripted_runtime.http_server import ScriptedHttpServer +from sglang.test.scripted_runtime.req_handle import ScriptedReqHandle +from sglang.test.scripted_runtime.test_case import ScriptedTestCase +from sglang.test.scripted_runtime_chunked_helpers import ( + advance_to_decode_step, + advance_to_nth_chunk, + base_engine_kwargs, + exhaust_row_pool, + run_until_finished, + warmup_radix, +) +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=460, stage="base-b", runner_config="1-gpu-small") + + 
+_CHUNK_SIZE = 64 +_LONG_PROMPT_LEN = 4 * _CHUNK_SIZE - 3 +_SHORT_PROMPT_LEN = 16 +_DECODE_MAX_NEW_TOKENS = 8 + +_ENGINE_KWARGS = base_engine_kwargs(chunked_prefill_size=_CHUNK_SIZE) + + +def _script_noop(t: ScriptedContext): + yield + + +class TestScriptedRuntimeCore(ScriptedTestCase): + ENGINE_KWARGS = _ENGINE_KWARGS + + def test_start_req_auto_rid_and_finishes(self): + self.server.execute_script(self._script_start_req_auto_rid_and_finishes) + + @staticmethod + def _script_start_req_auto_rid_and_finishes(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4) + assert r.rid.startswith("scripted-"), f"unexpected auto rid {r.rid!r}" + yield from run_until_finished(r) + assert r.finished, "auto-rid req did not finish" + + def test_start_req_explicit_rid(self): + self.server.execute_script(self._script_start_req_explicit_rid) + + @staticmethod + def _script_start_req_explicit_rid(t: ScriptedContext): + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=2, rid="explicit-rid-test" + ) + assert r.rid == "explicit-rid-test", f"explicit rid not honored: {r.rid!r}" + yield + assert ( + t.find_req_by_rid("explicit-rid-test") is not None + ), "explicit rid not visible to the scheduler after one step" + + def test_find_req_by_rid_hit_and_miss(self): + self.server.execute_script(self._script_find_req_by_rid_hit_and_miss) + + @staticmethod + def _script_find_req_by_rid_hit_and_miss(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4) + yield + found = t.find_req_by_rid(r.rid) + assert ( + found is not None and found.rid == r.rid + ), f"find_req_by_rid missed the live rid {r.rid!r}" + assert ( + t.find_req_by_rid("no-such-rid") is None + ), "find_req_by_rid returned a req for an unknown rid" + yield from run_until_finished(r) + + def test_is_finished_reflects_completion(self): + self.server.execute_script(self._script_is_finished_reflects_completion) + + @staticmethod + def 
_script_is_finished_reflects_completion(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4) + yield + assert not t.is_finished(r.rid), "is_finished True while the req still runs" + yield from run_until_finished(r) + assert t.is_finished(r.rid), "is_finished False after the req completed" + + def test_req_handle_req_property(self): + self.server.execute_script(self._script_req_handle_req_property) + + @staticmethod + def _script_req_handle_req_property(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4) + yield + assert ( + r.req is not None and r.req.rid == r.rid + ), f"handle.req did not resolve to the live req for {r.rid!r}" + bogus = ScriptedReqHandle(rid="no-such-rid", context=t) + assert bogus.req is None, "handle.req returned a req for an unknown rid" + yield from run_until_finished(r) + + def test_req_handle_finished_property(self): + self.server.execute_script(self._script_req_handle_finished_property) + + @staticmethod + def _script_req_handle_finished_property(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4) + yield + assert not r.finished, "handle.finished True while the req still runs" + yield from run_until_finished(r) + assert r.finished, "handle.finished False after the req completed" + + def test_is_chunking_true_mid_prefill_false_after(self): + self.server.execute_script( + self._script_is_chunking_true_mid_prefill_false_after + ) + + @staticmethod + def _script_is_chunking_true_mid_prefill_false_after(t: ScriptedContext): + r = t.start_req(prompt_len=_LONG_PROMPT_LEN, max_new_tokens=2) + yield from advance_to_nth_chunk(r, 1) + assert r.is_chunking, "handle.is_chunking False during multi-chunk prefill" + assert t.is_chunking( + r.rid + ), "context.is_chunking False during multi-chunk prefill" + assert not t.is_chunking( + "no-such-rid" + ), "context.is_chunking True for an unknown rid" + yield from run_until_finished(r) + assert not 
r.is_chunking, "handle.is_chunking still True after finish" + assert not t.is_chunking(r.rid), "context.is_chunking still True after finish" + + def test_pause_retract_parks_in_waiting_queue_then_resumes(self): + self.server.execute_script(self._script_pause_retract_parks_then_resumes) + + @staticmethod + def _script_pause_retract_parks_then_resumes(t: ScriptedContext): + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=_DECODE_MAX_NEW_TOKENS + ) + yield from advance_to_decode_step(r, 1) + assert r.req is not None, "req vanished before pause(retract)" + + t.pause_generation(mode="retract") + yield + + req = r.req + assert ( + req is not None and req in t.scheduler.waiting_queue + ), f"pause(retract) did not park the req in waiting_queue; found {req!r}" + + frozen = len(req.output_ids) + for _ in range(3): + yield + req = r.req + assert req is not None and len(req.output_ids) == frozen, ( + f"paused(retract) engine advanced the req: " + f"{len(req.output_ids) if req is not None else None} != {frozen}" + ) + + t.continue_generation() + yield from run_until_finished(r) + assert r.finished, "req did not finish after pause(retract)/continue" + + def test_pause_in_place_freezes_then_resumes(self): + self.server.execute_script(self._script_pause_in_place_freezes_then_resumes) + + @staticmethod + def _script_pause_in_place_freezes_then_resumes(t: ScriptedContext): + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=_DECODE_MAX_NEW_TOKENS + ) + yield from advance_to_decode_step(r, 1) + req = r.req + assert req is not None, "req vanished before pause(in_place)" + frozen = len(req.output_ids) + + t.pause_generation(mode="in_place") + yield + + req = r.req + assert ( + req is not None and req not in t.scheduler.waiting_queue + ), f"pause(in_place) should not retract the req to waiting_queue; found {req!r}" + + for _ in range(3): + yield + req = r.req + assert req is not None and len(req.output_ids) == frozen, ( + f"paused(in_place) engine advanced 
the req: " + f"{len(req.output_ids) if req is not None else None} != {frozen}" + ) + + t.continue_generation() + yield from run_until_finished(r) + assert r.finished, "req did not finish after pause(in_place)/continue" + + def test_continue_generation_with_torch_empty_cache(self): + self.server.execute_script( + self._script_continue_generation_with_torch_empty_cache + ) + + @staticmethod + def _script_continue_generation_with_torch_empty_cache(t: ScriptedContext): + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=_DECODE_MAX_NEW_TOKENS + ) + yield from advance_to_decode_step(r, 1) + t.pause_generation(mode="retract") + yield + t.continue_generation(torch_empty_cache=True) + yield from run_until_finished(r) + assert r.finished, "req did not finish after continue(torch_empty_cache=True)" + + def test_abort_all_finishes_running_req(self): + self.server.execute_script(self._script_abort_all_finishes_running_req) + + @staticmethod + def _script_abort_all_finishes_running_req(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=64) + yield from advance_to_decode_step(r, 1) + assert not r.finished, "req finished before abort_all could act" + + t.abort_all() + for _ in range(8): + yield + if r.finished: + break + assert r.finished, "req did not finish after abort_all" + + def test_flush_cache_clears_radix_tree(self): + self.server.execute_script(self._script_flush_cache_clears_radix_tree) + + @staticmethod + def _script_flush_cache_clears_radix_tree(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=2) + yield from run_until_finished(r) + for _ in range(3): + yield + assert t.get_all_node_hit_counts(), "expected radix nodes after a finished req" + + t.flush_cache() + yield + assert ( + not t.get_all_node_hit_counts() + ), "flush_cache did not clear the radix tree" + + def test_get_all_node_hit_counts_increments_on_cache_hit(self): + 
self.server.execute_script(self._script_hit_counts_increment_on_cache_hit) + + @staticmethod + def _script_hit_counts_increment_on_cache_hit(t: ScriptedContext): + t.flush_cache() + yield + + r1 = t.start_req(prompt_len=_LONG_PROMPT_LEN, max_new_tokens=2) + yield from run_until_finished(r1) + for _ in range(3): + yield + counts_before = t.get_all_node_hit_counts() + assert counts_before, "expected radix nodes after the first req" + max_before = max(counts_before.values()) + + r2 = t.start_req(prompt_len=_LONG_PROMPT_LEN, max_new_tokens=2) + yield from run_until_finished(r2) + counts_after = t.get_all_node_hit_counts() + assert max(counts_after.values()) > max_before, ( + f"identical prompt did not bump any node hit_count: " + f"before={max_before} after={max(counts_after.values())}" + ) + + def test_get_all_node_lock_refs_held_during_run_released_after(self): + self.server.execute_script(self._script_all_node_lock_refs_held_then_released) + + @staticmethod + def _script_all_node_lock_refs_held_then_released(t: ScriptedContext): + t.flush_cache() + yield + + r = t.start_req( + prompt_len=_LONG_PROMPT_LEN, max_new_tokens=_DECODE_MAX_NEW_TOKENS + ) + yield from advance_to_decode_step(r, 1) + lock_refs = t.get_all_node_lock_refs() + assert ( + lock_refs and max(lock_refs.values()) >= 1 + ), f"expected a locked radix node while the req runs; got {lock_refs}" + + yield from run_until_finished(r) + for _ in range(3): + yield + released = t.get_all_node_lock_refs() + assert released and all( + ref == 0 for ref in released.values() + ), f"radix nodes still locked after the req finished: {released}" + + def test_start_req_ignore_eos_runs_full_length(self): + self.server.execute_script(self._script_ignore_eos_runs_full_length) + + @staticmethod + def _script_ignore_eos_runs_full_length(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=6, ignore_eos=True) + yield from run_until_finished(r) + req = r.req + assert req is not None, "finished req vanished before 
its output could be read" + assert ( + len(req.output_ids) == 6 + ), f"ignore_eos must decode the full length; got {list(req.output_ids)!r}" + + def test_start_req_priority_is_propagated(self): + self.server.execute_script(self._script_priority_is_propagated) + + @staticmethod + def _script_priority_is_propagated(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4, priority=7) + yield from advance_to_decode_step(r, 1) + req = r.req + assert req is not None and req.priority == 7, ( + f"priority not propagated to the scheduler req; " + f"got {None if req is None else req.priority}" + ) + yield from run_until_finished(r) + + def test_start_req_dp_rank_zero_accepted(self): + self.server.execute_script(self._script_dp_rank_zero_accepted) + + @staticmethod + def _script_dp_rank_zero_accepted(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=4, dp_rank=0) + yield from run_until_finished(r) + assert r.finished, "dp_rank=0 req did not finish" + + def test_abort_single_handle_finishes_with_abort_reason(self): + self.server.execute_script(self._script_abort_single_handle) + + @staticmethod + def _script_abort_single_handle(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=64) + yield from advance_to_decode_step(r, 1) + assert not r.finished, "req finished before abort could act" + + t.abort(r) + saw_abort_reason = False + for _ in range(16): + yield + req = r.req + if req is not None and isinstance(req.finished_reason, FINISH_ABORT): + saw_abort_reason = True + if r.finished: + break + assert r.finished, "single-handle abort did not finish the req" + assert saw_abort_reason, "aborted req never carried a FINISH_ABORT reason" + + def test_abort_single_handle_leaves_other_reqs_running(self): + self.server.execute_script(self._script_abort_single_handle_targeted) + + @staticmethod + def _script_abort_single_handle_targeted(t: ScriptedContext): + keep = 
t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=64) + victim = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=64) + yield from advance_to_decode_step(keep, 1) + + t.abort(victim) + for _ in range(16): + yield + if victim.finished: + break + assert victim.finished, "targeted abort did not finish the victim req" + assert not keep.finished, "targeted abort wrongly finished the other req" + + t.abort(keep) + yield from run_until_finished(keep) + + def test_list_active_reqs_contains_live_req(self): + self.server.execute_script(self._script_list_active_reqs) + + @staticmethod + def _script_list_active_reqs(t: ScriptedContext): + assert t.list_active_reqs() == [], "no active reqs expected when idle" + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from advance_to_decode_step(r, 1) + actives = t.list_active_reqs() + assert any(req.rid == r.rid for req in actives), ( + f"live req {r.rid!r} missing from active reqs " + f"{[req.rid for req in actives]!r}" + ) + yield from run_until_finished(r) + for _ in range(5): + yield + assert ( + t.list_active_reqs() == [] + ), "active reqs should drain to empty after finish" + + def test_kv_pages_held_during_run_released_after(self): + self.server.execute_script(self._script_kv_pages_set_then_released) + + @staticmethod + def _script_kv_pages_set_then_released(t: ScriptedContext): + r = t.start_req( + prompt_len=_LONG_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from advance_to_decode_step(r, 1) + assert r.kv_pages > 0, f"expected kv_pages>0 mid-run; got {r.kv_pages}" + yield from run_until_finished(r) + assert r.kv_pages == 0, f"kv_pages should be 0 after finish; got {r.kv_pages}" + + def test_engine_stats_tracks_kv_pool(self): + self.server.execute_script(self._script_engine_stats_tracks_kv) + + @staticmethod + def _script_engine_stats_tracks_kv(t: ScriptedContext): + stats = t.engine_stats() + for key in 
("kv_pool_free", "req_pool_free", "req_pool_total", "page_size"): + assert key in stats, f"engine_stats missing {key!r}: {stats!r}" + baseline_free = stats["kv_pool_free"] + assert baseline_free > 0 + r = t.start_req( + prompt_len=_LONG_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from advance_to_decode_step(r, 1) + during_free = t.engine_stats()["kv_pool_free"] + assert during_free < baseline_free, ( + f"kv_pool_free should drop while a req holds KV; " + f"baseline={baseline_free} during={during_free}" + ) + yield from run_until_finished(r) + for _ in range(5): + yield + t.flush_cache() + yield + after_free = t.engine_stats()["kv_pool_free"] + assert ( + after_free >= during_free + ), f"kv_pool_free should recover after finish; during={during_free} after={after_free}" + + def test_lock_refs_held_during_run_released_after(self): + self.server.execute_script(self._script_lock_refs_held_then_released) + + @staticmethod + def _script_lock_refs_held_then_released(t: ScriptedContext): + t.flush_cache() + yield + r = t.start_req( + prompt_len=_LONG_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from advance_to_decode_step(r, 1) + assert ( + r.lock_refs >= 1 + ), f"radix lock_ref must be held mid-run; got {r.lock_refs}" + yield from run_until_finished(r) + assert ( + r.lock_refs == 0 + ), f"lock_refs must be released after finish; got {r.lock_refs}" + + def test_batch_composition_shape_and_disjoint(self): + self.server.execute_script(self._script_batch_composition) + + @staticmethod + def _script_batch_composition(t: ScriptedContext): + r = t.start_req(prompt_len=_LONG_PROMPT_LEN, max_new_tokens=2) + yield from advance_to_nth_chunk(r, 1) + comp = t.batch_composition() + assert set(comp) == { + "prefill", + "decode", + "chunked", + "running", + }, f"unexpected batch_composition keys: {comp!r}" + assert ( + r.rid in comp["chunked"] + ), f"chunked req must be in 'chunked'; got {comp!r}" + prefill, 
decode, chunked = ( + set(comp["prefill"]), + set(comp["decode"]), + set(comp["chunked"]), + ) + assert ( + prefill.isdisjoint(decode) + and prefill.isdisjoint(chunked) + and decode.isdisjoint(chunked) + ), f"prefill/decode/chunked subsets must be disjoint; got {comp!r}" + yield from run_until_finished(r) + assert ( + t.batch_composition()["chunked"] == [] + ), "no chunked req should remain after the req finishes" + + def test_chunks_done_zero_for_unchunked_prompt(self): + self.server.execute_script(self._script_chunks_done_zero) + + @staticmethod + def _script_chunks_done_zero(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=2, ignore_eos=True) + yield from run_until_finished(r) + assert ( + r.chunks_done == 0 + ), f"prompt <= chunk must not chunk; got {r.chunks_done}" + + def test_chunks_done_counts_two_chunks(self): + self.server.execute_script(self._script_chunks_done_two) + + @staticmethod + def _script_chunks_done_two(t: ScriptedContext): + r = t.start_req(prompt_len=_CHUNK_SIZE + 2, max_new_tokens=2, ignore_eos=True) + yield from run_until_finished(r) + assert ( + r.chunks_done == 2 + ), f"chunk_size+2 prompt -> 2 chunks; got {r.chunks_done}" + + def test_chunks_done_scales_with_prompt(self): + self.server.execute_script(self._script_chunks_done_five) + + @staticmethod + def _script_chunks_done_five(t: ScriptedContext): + r = t.start_req(prompt_len=5 * _CHUNK_SIZE, max_new_tokens=2, ignore_eos=True) + yield from run_until_finished(r) + assert ( + r.chunks_done == 5 + ), f"5*chunk_size prompt -> 5 chunks; got {r.chunks_done}" + + def test_is_idle_reflects_engine_activity(self): + self.server.execute_script(self._script_is_idle_reflects_activity) + + @staticmethod + def _script_is_idle_reflects_activity(t: ScriptedContext): + assert t.is_idle, "engine should be idle at script start" + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from 
advance_to_decode_step(r, 1) + assert not t.is_idle, "is_idle True while a req is decoding" + yield from run_until_finished(r) + for _ in range(5): + yield + assert t.is_idle, "is_idle False after the req drained" + assert t.is_fully_idle, "is_fully_idle False after the req drained" + + def test_status_transitions_running_to_finished(self): + self.server.execute_script(self._script_status_transitions) + + @staticmethod + def _script_status_transitions(t: ScriptedContext): + assert ( + t.status("no-such-rid") == "unknown" + ), "status of a never-seen rid must be 'unknown'" + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from advance_to_decode_step(r, 1) + assert ( + r.status == "running" + ), f"decoding req status should be running; got {r.status!r}" + yield from run_until_finished(r) + assert ( + r.status == "finished" + ), f"completed req status should be finished; got {r.status!r}" + + def test_last_batch_forward_mode_extend_then_decode(self): + self.server.execute_script(self._script_last_batch_forward_mode) + + @staticmethod + def _script_last_batch_forward_mode(t: ScriptedContext): + r = t.start_req( + prompt_len=_LONG_PROMPT_LEN, + max_new_tokens=_DECODE_MAX_NEW_TOKENS, + ignore_eos=True, + ) + yield from advance_to_nth_chunk(r, 1) + assert t.last_batch_forward_mode in ("EXTEND", "MIXED"), ( + f"mid-prefill batch should be an extend mode; " + f"got {t.last_batch_forward_mode!r}" + ) + yield from advance_to_decode_step(r, 1) + assert ( + t.last_batch_forward_mode == "DECODE" + ), f"decode batch mode should be DECODE; got {t.last_batch_forward_mode!r}" + yield from run_until_finished(r) + + def test_remaining_prompt_tokens_shrinks_to_zero(self): + self.server.execute_script(self._script_remaining_prompt_tokens) + + @staticmethod + def _script_remaining_prompt_tokens(t: ScriptedContext): + r = t.start_req(prompt_len=_LONG_PROMPT_LEN, max_new_tokens=2) + yield from advance_to_nth_chunk(r, 
1) + rem = r.remaining_prompt_tokens + assert 0 < rem < _LONG_PROMPT_LEN, ( + f"mid-prefill remaining_prompt_tokens should be partial; " + f"got {rem} (prompt={_LONG_PROMPT_LEN})" + ) + yield from run_until_finished(r) + assert ( + r.remaining_prompt_tokens == 0 + ), f"finished req should have 0 remaining; got {r.remaining_prompt_tokens}" + + def test_evict_radix_full_clears_tree_and_rejects_prefix(self): + self.server.execute_script(self._script_evict_radix) + + @staticmethod + def _script_evict_radix(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=2) + yield from run_until_finished(r) + for _ in range(3): + yield + assert t.get_all_node_hit_counts(), "expected radix nodes after a finished req" + + t.evict_radix(prefix_tokens=None) + yield + assert ( + not t.get_all_node_hit_counts() + ), "evict_radix(prefix_tokens=None) did not clear the radix tree" + + rejected = False + try: + t.evict_radix(prefix_tokens=[1, 2, 3]) + except AssertionError: + rejected = True + assert ( + rejected + ), "evict_radix must reject a non-None prefix (only full evict supported)" + + def test_warmup_radix_populates_prefix(self): + self.server.execute_script(self._script_warmup_radix_populates_prefix) + + @staticmethod + def _script_warmup_radix_populates_prefix(t: ScriptedContext): + t.flush_cache() + yield + yield from warmup_radix(t, [1] * (2 * _CHUNK_SIZE)) + + r = t.start_req(prompt_len=2 * _CHUNK_SIZE + 1, max_new_tokens=1) + yield from run_until_finished(r) + assert ( + r.req is not None + ), "finished req vanished before cached_tokens could be read" + assert r.req.cached_tokens > 0, ( + f"req with the warmed prefix should hit the radix cache; " + f"got cached_tokens={r.req.cached_tokens}" + ) + + def test_exhaust_kv_creates_pressure_and_release_restores(self): + self.server.execute_script(self._script_exhaust_kv_round_trip) + + @staticmethod + def _script_exhaust_kv_round_trip(t: ScriptedContext): + stats = t.engine_stats() + page = 
stats["page_size"] + baseline = stats["kv_pool_free"] + assert ( + baseline > 4 * page + ), f"need KV headroom to test pressure; baseline={baseline}" + + t.exhaust_kv(leave_pages=2) + pressured = t.engine_stats()["kv_pool_free"] + assert ( + pressured < baseline + ), f"exhaust_kv must reduce free KV; pressured={pressured} baseline={baseline}" + assert ( + pressured <= 3 * page + ), f"exhaust_kv(leave_pages=2) left too much free KV; got {pressured} (page={page})" + + t._release_exhausted_pools() + restored = t.engine_stats()["kv_pool_free"] + assert ( + restored == baseline + ), f"release must restore the full pool; restored={restored} baseline={baseline}" + yield + + def test_exhaust_row_pool_leaves_requested_free_rows(self): + self.server.execute_script(self._script_exhaust_row_pool) + + @staticmethod + def _script_exhaust_row_pool(t: ScriptedContext): + avail = t.engine_stats()["req_pool_free"] + assert avail >= 5, f"need free rows to test; avail={avail}" + target_free = avail - 3 + + yield from exhaust_row_pool(t, leave_rows=target_free) + + free_after = t.engine_stats()["req_pool_free"] + assert ( + free_after <= target_free + ), f"exhaust_row_pool should leave <= {target_free} free rows; got {free_after}" + assert ( + free_after < avail + ), f"exhaust_row_pool did not consume any rows; avail={avail} after={free_after}" + + def test_forward_ct_advances_once_per_yield(self): + self.server.execute_script(self._script_forward_ct_advances_once_per_yield) + + @staticmethod + def _script_forward_ct_advances_once_per_yield(t: ScriptedContext): + sched = t.scheduler + + before_no_yield = sched.forward_ct + r = t.start_req( + prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=128, ignore_eos=True + ) + assert ( + sched.forward_ct == before_no_yield + ), f"forward_ct moved without a yield: {before_no_yield} -> {sched.forward_ct}" + + yield from advance_to_decode_step(r, 1) + + for n in (1, 3, 5): + before = sched.forward_ct + for _ in range(n): + yield + advanced = 
sched.forward_ct - before + assert advanced == n, ( + f"forward_ct advanced by {advanced} over {n} yields " + f"(before={before} after={sched.forward_ct})" + ) + + t.abort_all() + for _ in range(8): + yield + if r.finished: + break + + def test_empty_script_returns_immediately(self): + self.server.execute_script(self._script_empty_return) + + @staticmethod + def _script_empty_return(t: ScriptedContext): + if False: + yield + return + + def test_failing_script_surfaces_and_session_survives(self): + with self.assertRaises(AssertionError) as ctx: + self.server.execute_script(self._script_assertion_failure) + self.assertIn("boom", str(ctx.exception)) + self.server.execute_script(self._script_minimal_ok) + + @staticmethod + def _script_assertion_failure(t: ScriptedContext): + yield + assert False, "boom" + + @staticmethod + def _script_minimal_ok(t: ScriptedContext): + r = t.start_req(prompt_len=_SHORT_PROMPT_LEN, max_new_tokens=2) + yield + yield + + def test_runtime_error_in_script_surfaces_to_caller(self): + with self.assertRaises(AssertionError) as ctx: + self.server.execute_script(self._script_runtime_error) + err_text = str(ctx.exception) + self.assertIn("RuntimeError", err_text) + self.assertIn("simulated runtime error", err_text) + + @staticmethod + def _script_runtime_error(t: ScriptedContext): + yield + raise RuntimeError("simulated runtime error") + + def test_non_generator_script_rejected(self): + with self.assertRaises(AssertionError) as ctx: + self.server.execute_script(self._script_not_a_generator) + err_text = str(ctx.exception) + self.assertIn("TypeError", err_text) + self.assertIn("NoneType", err_text) + + @staticmethod + def _script_not_a_generator(t: ScriptedContext): + return None + + +class TestScriptedRuntimeSession(CustomTestCase): + + def test_shutdown_is_idempotent(self): + session = ScriptedHttpServer.start(**_ENGINE_KWARGS) + session.shutdown() + session.shutdown() + assert session._shutdown_done is True + + def 
test_dirty_session_refuses_to_run(self): + session = ScriptedHttpServer.start(**_ENGINE_KWARGS) + try: + session._dirty = "test dirty" + with self.assertRaises(RuntimeError) as ctx: + session.execute_script(_script_noop) + assert "test dirty" in str(ctx.exception) + finally: + session.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/scripted_runtime/test_background_http_poster.py b/test/registered/unit/scripted_runtime/test_background_http_poster.py new file mode 100644 index 000000000000..12da761dfce0 --- /dev/null +++ b/test/registered/unit/scripted_runtime/test_background_http_poster.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import asyncio +import threading +import unittest +from concurrent.futures import Future +from unittest.mock import MagicMock + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.scripted_runtime import background_http_poster as bg_poster +from sglang.test.scripted_runtime.background_http_poster import BackgroundHttpPoster +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=8, suite="base-a-test-cpu") + + +class _FakeResponse: + + def __init__(self) -> None: + self.read_called = False + + async def read(self) -> bytes: + self.read_called = True + return b"chunk-1chunk-2" + + +class _FakePostCM: + + def __init__(self, response: _FakeResponse) -> None: + self._response = response + + async def __aenter__(self) -> _FakeResponse: + return self._response + + async def __aexit__(self, *exc_info: object) -> bool: + return False + + +class _FakeSession: + + def __init__(self) -> None: + self.closed = False + self.calls: list[tuple[str, object]] = [] + self.response = _FakeResponse() + + def post(self, url: str, json: object) -> _FakePostCM: + self.calls.append((url, json)) + return _FakePostCM(self.response) + + async def close(self) -> None: + self.closed = True + + +class TestBackgroundHttpPosterLifecycle(CustomTestCase): + + def 
test_init_starts_running_loop_on_daemon_thread(self): + poster = BackgroundHttpPoster() + self.addCleanup(poster.close) + + self.assertIsNotNone(poster._loop) + self.assertTrue(poster._loop.is_running()) + self.assertTrue(poster._thread.is_alive()) + self.assertTrue(poster._thread.daemon) + + def test_close_stops_loop_and_joins_thread(self): + poster = BackgroundHttpPoster() + + poster.close() + + self.assertFalse(poster._loop.is_running()) + self.assertFalse(poster._thread.is_alive()) + + def test_close_is_safe_when_loop_never_started(self): + poster = BackgroundHttpPoster.__new__(BackgroundHttpPoster) + poster._loop = None + poster._thread = None + poster._session = None + + poster.close() + + +class TestBackgroundHttpPosterSubmitCoro(CustomTestCase): + + def test_submit_coro_runs_on_background_loop_thread(self): + poster = BackgroundHttpPoster() + self.addCleanup(poster.close) + done = threading.Event() + recorded: dict[str, str] = {} + + async def record_thread() -> None: + recorded["thread_name"] = threading.current_thread().name + done.set() + + poster.submit_coro(record_thread()) + + self.assertTrue(done.wait(timeout=5.0)) + self.assertEqual(recorded["thread_name"], "scripted-runtime-async") + + def test_log_coro_exception_logs_real_failure(self): + future: Future = Future() + future.set_exception(RuntimeError("boom")) + + original = bg_poster.logger.exception + bg_poster.logger.exception = MagicMock() + try: + BackgroundHttpPoster._log_coro_exception(future) + bg_poster.logger.exception.assert_called_once() + finally: + bg_poster.logger.exception = original + + def test_log_coro_exception_swallows_cancellation_silently(self): + future: Future = Future() + future.set_exception(asyncio.CancelledError()) + + original = bg_poster.logger.exception + bg_poster.logger.exception = MagicMock() + try: + BackgroundHttpPoster._log_coro_exception(future) + bg_poster.logger.exception.assert_not_called() + finally: + bg_poster.logger.exception = original + + def 
test_log_coro_exception_quiet_on_success(self): + future: Future = Future() + future.set_result(None) + + original = bg_poster.logger.exception + bg_poster.logger.exception = MagicMock() + try: + BackgroundHttpPoster._log_coro_exception(future) + bg_poster.logger.exception.assert_not_called() + finally: + bg_poster.logger.exception = original + + +class TestBackgroundHttpPosterEnsureSession(CustomTestCase): + + def test_ensure_session_creates_reuses_then_recreates_when_closed(self): + poster = BackgroundHttpPoster() + self.addCleanup(poster.close) + sessions = [MagicMock(closed=False), MagicMock(closed=False)] + original = bg_poster.aiohttp.ClientSession + original_connector = bg_poster.aiohttp.TCPConnector + bg_poster.aiohttp.ClientSession = MagicMock(side_effect=sessions) + bg_poster.aiohttp.TCPConnector = MagicMock() + try: + first = poster._ensure_session() + self.assertIs(first, sessions[0]) + + reused = poster._ensure_session() + self.assertIs(reused, sessions[0]) + + sessions[0].closed = True + recreated = poster._ensure_session() + self.assertIs(recreated, sessions[1]) + finally: + bg_poster.aiohttp.ClientSession = original + bg_poster.aiohttp.TCPConnector = original_connector + poster._session = None + + +class TestBackgroundHttpPosterPost(CustomTestCase): + + def _run_on_loop(self, poster: BackgroundHttpPoster, coro) -> None: + asyncio.run_coroutine_threadsafe(coro, poster._loop).result(timeout=5.0) + + def test_post_posts_json_and_reads_body(self): + poster = BackgroundHttpPoster() + self.addCleanup(poster.close) + session = _FakeSession() + poster._ensure_session = lambda: session + + self._run_on_loop(poster, poster.post("http://h/flush", {"a": 1})) + + self.assertEqual(session.calls, [("http://h/flush", {"a": 1})]) + self.assertTrue(session.response.read_called) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/scripted_runtime/test_http_server.py b/test/registered/unit/scripted_runtime/test_http_server.py new file 
mode 100644 index 000000000000..7134d955efab --- /dev/null +++ b/test/registered/unit/scripted_runtime/test_http_server.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import unittest + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.scripted_runtime.http_server import ScriptedHttpServer +from sglang.test.scripted_runtime.io_struct import ( + HookReady, + RunScript, + ScriptFailed, + ScriptSucceeded, +) +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=15, suite="base-a-test-cpu") + + +def _sample_script(ctx, *args): + yield + + +_EXPECTED_FN_PATH = f"{_sample_script.__module__}:{_sample_script.__qualname__}" + + +class _FakePairSocket: + + def __init__(self, *, poll_result: bool, reply: object = None) -> None: + self._poll_result = poll_result + self._reply = reply + self.sent: list = [] + + def send_pyobj(self, obj: object) -> None: + self.sent.append(obj) + + def poll(self, timeout_ms: int) -> bool: + return self._poll_result + + def recv_pyobj(self) -> object: + return self._reply + + +class _FakeProcess: + + def __init__(self, *, alive: bool) -> None: + self._alive = alive + + def is_alive(self) -> bool: + return self._alive + + +def _make_server(socket: _FakePairSocket, process: _FakeProcess) -> ScriptedHttpServer: + server = ScriptedHttpServer.__new__(ScriptedHttpServer) + server._socket = socket + server._server_process = process + server._dirty = None + return server + + +class TestExecuteScriptReplyMatching(CustomTestCase): + + def test_returns_on_script_succeeded(self): + socket = _FakePairSocket(poll_result=True, reply=ScriptSucceeded()) + server = _make_server(socket, _FakeProcess(alive=True)) + + server.execute_script(_sample_script) + + self.assertEqual(socket.sent, [RunScript(fn_path=_EXPECTED_FN_PATH, args=())]) + + def test_forwards_args_in_run_script(self): + socket = _FakePairSocket(poll_result=True, reply=ScriptSucceeded()) + server = _make_server(socket, 
_FakeProcess(alive=True)) + + server.execute_script(_sample_script, args=(1, "two")) + + self.assertEqual( + socket.sent, [RunScript(fn_path=_EXPECTED_FN_PATH, args=(1, "two"))] + ) + + def test_script_failed_reply_raises_assertion_with_traceback(self): + socket = _FakePairSocket( + poll_result=True, reply=ScriptFailed(traceback="REMOTE-TB-MARKER") + ) + server = _make_server(socket, _FakeProcess(alive=True)) + + with self.assertRaisesRegex(AssertionError, "REMOTE-TB-MARKER"): + server.execute_script(_sample_script) + + def test_unexpected_reply_raises_runtime_error(self): + socket = _FakePairSocket(poll_result=True, reply=HookReady()) + server = _make_server(socket, _FakeProcess(alive=True)) + + with self.assertRaisesRegex(RuntimeError, "unexpected message"): + server.execute_script(_sample_script) + + +class TestExecuteScriptNoReply(CustomTestCase): + + def test_timeout_when_process_still_alive(self): + socket = _FakePairSocket(poll_result=False) + server = _make_server(socket, _FakeProcess(alive=True)) + + with self.assertRaisesRegex(TimeoutError, "timed out"): + server.execute_script(_sample_script, timeout_s=0.01) + self.assertIn("timed out", server._dirty) + + def test_runtime_error_when_process_died(self): + socket = _FakePairSocket(poll_result=False) + server = _make_server(socket, _FakeProcess(alive=False)) + + with self.assertRaisesRegex(RuntimeError, "died before responding"): + server.execute_script(_sample_script, timeout_s=0.01) + self.assertIn("died before responding", server._dirty) + + +class TestExecuteScriptDirtyGuard(CustomTestCase): + + def test_refuses_to_run_when_already_dirty(self): + socket = _FakePairSocket(poll_result=True, reply=ScriptSucceeded()) + server = _make_server(socket, _FakeProcess(alive=True)) + server._dirty = "prior timeout" + + with self.assertRaisesRegex(RuntimeError, "dirty"): + server.execute_script(_sample_script) + self.assertEqual(socket.sent, []) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/test/registered/unit/scripted_runtime/test_scheduler_hook.py b/test/registered/unit/scripted_runtime/test_scheduler_hook.py new file mode 100644 index 000000000000..2044fa4aa91e --- /dev/null +++ b/test/registered/unit/scripted_runtime/test_scheduler_hook.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.scripted_runtime import scheduler_hook +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="base-a-test-cpu") + + +def _yielding_gen(): + yield + yield + + +def _empty_gen(): + return + yield # pragma: no cover — makes this a generator function + + +def _raising_gen(): + raise ValueError("scripted-boom") + yield # pragma: no cover — makes this a generator function + + +class TestAdvanceGenerator(CustomTestCase): + + def test_not_done_when_generator_yields(self): + done, exc_tb = scheduler_hook._advance_generator(_yielding_gen()) + + self.assertEqual((done, exc_tb), (False, None)) + + def test_done_without_traceback_on_stop_iteration(self): + done, exc_tb = scheduler_hook._advance_generator(_empty_gen()) + + self.assertEqual((done, exc_tb), (True, None)) + + def test_done_with_traceback_on_exception(self): + original = scheduler_hook.logger.exception + scheduler_hook.logger.exception = MagicMock() + try: + done, exc_tb = scheduler_hook._advance_generator(_raising_gen()) + scheduler_hook.logger.exception.assert_called_once() + finally: + scheduler_hook.logger.exception = original + + self.assertTrue(done) + self.assertIsNotNone(exc_tb) + self.assertIn("ValueError", exc_tb) + self.assertIn("scripted-boom", exc_tb) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/scripted_runtime/test_scripted_runtime_utils.py b/test/registered/unit/scripted_runtime/test_scripted_runtime_utils.py new file mode 100644 index 000000000000..29464e4c181a --- /dev/null +++ 
b/test/registered/unit/scripted_runtime/test_scripted_runtime_utils.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import json +import os +import sys +import unittest + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.scripted_runtime.utils import ensure_script_importable, resolve_fn +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=2, suite="base-a-test-cpu") + + +class TestResolveFn(CustomTestCase): + + def test_resolves_top_level_function(self): + self.assertIs(resolve_fn("json:dumps"), json.dumps) + + def test_resolves_nested_attribute_path(self): + self.assertIs(resolve_fn("os:path.join"), os.path.join) + + def test_rejects_missing_colon(self): + with self.assertRaisesRegex(ValueError, "module.path:function_name"): + resolve_fn("json.dumps") + + def test_rejects_empty_module(self): + with self.assertRaisesRegex(ValueError, "module.path:function_name"): + resolve_fn(":dumps") + + def test_rejects_empty_function(self): + with self.assertRaisesRegex(ValueError, "module.path:function_name"): + resolve_fn("json:") + + def test_rejects_non_callable_target(self): + with self.assertRaisesRegex(TypeError, "not callable"): + resolve_fn("math:pi") + + def test_propagates_missing_module_error(self): + with self.assertRaises(ModuleNotFoundError): + resolve_fn("sglang_no_such_module_zzz:foo") + + def test_propagates_missing_attribute_error(self): + with self.assertRaises(AttributeError): + resolve_fn("json:no_such_attribute") + + +class TestEnsureScriptImportable(CustomTestCase): + + _FAKE_ENTRY = "/tmp/__scripted_runtime_ut_fake_sys_path__" + + def setUp(self): + self._orig_path = list(sys.path) + + def tearDown(self): + sys.path[:] = self._orig_path + + def test_inserts_new_entry_at_front(self): + self.assertNotIn(self._FAKE_ENTRY, sys.path) + + ensure_script_importable(self._FAKE_ENTRY) + + self.assertEqual(sys.path[0], self._FAKE_ENTRY) + + def test_noop_when_entry_is_none(self): + 
ensure_script_importable(None) + + self.assertEqual(sys.path, self._orig_path) + + def test_noop_when_entry_already_present(self): + sys.path.insert(0, self._FAKE_ENTRY) + + ensure_script_importable(self._FAKE_ENTRY) + + self.assertEqual(sys.path.count(self._FAKE_ENTRY), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/scripted_runtime/test_tokenizer_recv_proxy.py b/test/registered/unit/scripted_runtime/test_tokenizer_recv_proxy.py new file mode 100644 index 000000000000..8adbe40f6471 --- /dev/null +++ b/test/registered/unit/scripted_runtime/test_tokenizer_recv_proxy.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass + +import zmq + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.scripted_runtime.tokenizer_recv_proxy import ( + ScriptedTokenizerRecvProxy, +) +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=5, suite="base-a-test-cpu") + +import unittest + + +@dataclass +class _ControlMsg: + + tag: str = "flush" + + +@dataclass +class _StartReq: + + rid: str + + +class _FakeUnderlyingSocket: + + def __init__(self) -> None: + self._ready: deque = deque() + self._scheduled: list[list] = [] + + def feed(self, obj: object) -> None: + self._ready.append(obj) + + def feed_after_drain_cycles(self, obj: object, *, cycles: int) -> None: + self._scheduled.append([cycles, obj]) + + def recv_pyobj(self, flags: int = 0) -> object: + if self._ready: + return self._ready.popleft() + + for entry in self._scheduled: + entry[0] -= 1 + ready_now = [obj for remaining, obj in self._scheduled if remaining <= 0] + self._scheduled = [entry for entry in self._scheduled if entry[0] > 0] + self._ready.extend(ready_now) + + raise zmq.ZMQError(zmq.EAGAIN, "Resource temporarily unavailable") + + +def _is_control(obj: object) -> bool: + return isinstance(obj, _ControlMsg) + + +def _is_start_req(rid: str): + return lambda obj: 
isinstance(obj, _StartReq) and obj.rid == rid + + +class TestScriptedTokenizerRecvProxyRecv(CustomTestCase): + + def test_recv_pyobj_drains_then_pops_fifo(self): + underlying = _FakeUnderlyingSocket() + proxy = ScriptedTokenizerRecvProxy(underlying=underlying) + first, second = _ControlMsg("a"), _ControlMsg("b") + underlying.feed(first) + underlying.feed(second) + + self.assertIs(proxy.recv_pyobj(), first) + self.assertIs(proxy.recv_pyobj(), second) + + def test_recv_pyobj_empty_noblock_raises_eagain(self): + proxy = ScriptedTokenizerRecvProxy(underlying=_FakeUnderlyingSocket()) + + with self.assertRaises(zmq.ZMQError) as ctx: + proxy.recv_pyobj(zmq.NOBLOCK) + self.assertEqual(ctx.exception.errno, zmq.EAGAIN) + + def test_recv_pyobj_empty_blocking_raises_runtime_error(self): + proxy = ScriptedTokenizerRecvProxy(underlying=_FakeUnderlyingSocket()) + + with self.assertRaisesRegex(RuntimeError, "blocking recv is not supported"): + proxy.recv_pyobj() + + +class TestScriptedTokenizerRecvProxyWaitUntilArrived(CustomTestCase): + + def _proxy_with_stale_control(self): + underlying = _FakeUnderlyingSocket() + proxy = ScriptedTokenizerRecvProxy(underlying=underlying) + stale = _ControlMsg("stale") + underlying.feed(stale) + proxy.wait_until_arrived(_is_control, timeout_s=1.0) + return proxy, underlying, stale + + def test_wait_until_arrived_returns_on_first_match_when_buffer_empty(self): + underlying = _FakeUnderlyingSocket() + proxy = ScriptedTokenizerRecvProxy(underlying=underlying) + msg = _ControlMsg("first") + underlying.feed(msg) + + proxy.wait_until_arrived(_is_control, timeout_s=1.0) + + self.assertIs(proxy.recv_pyobj(), msg) + + def test_wait_until_arrived_skips_stale_same_type_object(self): + proxy, _, _ = self._proxy_with_stale_control() + + with self.assertRaises(TimeoutError): + proxy.wait_until_arrived(_is_control, timeout_s=0.05) + + def test_wait_until_arrived_returns_on_new_object_after_stale(self): + proxy, underlying, stale = 
self._proxy_with_stale_control() + fresh = _ControlMsg("fresh") + underlying.feed_after_drain_cycles(fresh, cycles=1) + + proxy.wait_until_arrived(_is_control, timeout_s=2.0) + + self.assertIs(proxy.recv_pyobj(), stale) + self.assertIs(proxy.recv_pyobj(), fresh) + + def test_wait_until_arrived_rid_predicate_ignores_stale_other_rid(self): + underlying = _FakeUnderlyingSocket() + proxy = ScriptedTokenizerRecvProxy(underlying=underlying) + old = _StartReq(rid="old") + underlying.feed(old) + proxy.wait_until_arrived(_is_start_req("old"), timeout_s=1.0) + + new = _StartReq(rid="new") + underlying.feed(new) + proxy.wait_until_arrived(_is_start_req("new"), timeout_s=1.0) + + self.assertIs(proxy.recv_pyobj(), old) + self.assertIs(proxy.recv_pyobj(), new) + + def test_wait_until_arrived_rid_predicate_skips_stale_same_rid(self): + underlying = _FakeUnderlyingSocket() + proxy = ScriptedTokenizerRecvProxy(underlying=underlying) + underlying.feed(_StartReq(rid="reused")) + proxy.wait_until_arrived(_is_start_req("reused"), timeout_s=1.0) + + with self.assertRaises(TimeoutError): + proxy.wait_until_arrived(_is_start_req("reused"), timeout_s=0.05) + + def test_wait_until_arrived_timeout_message_names_description(self): + proxy = ScriptedTokenizerRecvProxy(underlying=_FakeUnderlyingSocket()) + + with self.assertRaisesRegex(TimeoutError, "FlushCacheReqInput"): + proxy.wait_until_arrived( + _is_control, timeout_s=0.02, description="FlushCacheReqInput" + ) + + +if __name__ == "__main__": + unittest.main()