diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index ae3958d17830..dde12b65d725 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -753,6 +753,7 @@ class Envs: # KV-Canary / Token-Oracle (testing-only) # =================================================================== SGLANG_KV_CANARY_RING_CAPACITY = EnvInt(1024) + SGLANG_KV_CANARY_STATS_PRINT_EVERY_N_STEPS = EnvInt(100) SGLANG_KV_CANARY_ENABLE_WRITE_INPUT_ASSERT = EnvBool(False) SGLANG_KV_CANARY_PERTURB_REQ_TO_TOKEN_PROB = EnvFloat(0.0) SGLANG_KV_CANARY_PERTURB_WARMUP_STEPS = EnvInt(50) diff --git a/python/sglang/srt/kv_canary/config.py b/python/sglang/srt/kv_canary/config.py index 2ae617ee5081..0eb0ea194ec8 100644 --- a/python/sglang/srt/kv_canary/config.py +++ b/python/sglang/srt/kv_canary/config.py @@ -47,6 +47,8 @@ class CanaryConfig: expected_tokens from each req's ``origin_input_ids + output_ids`` (snapshotted at ForwardBatch.init_new) and compare against the canary's stored tokens at verify time. Independent of ``enable_write_input_assert``. + stats_print_every_n_steps: 0 disables periodic stats logging; positive N prints + "canary protected N tokens, ran M sweep passes, K violations so far" every N forward steps. """ mode: CanaryMode @@ -55,6 +57,7 @@ class CanaryConfig: real_kv_hash_mode: RealKvHashMode enable_write_input_assert: bool enable_verify_token_assert: bool + stats_print_every_n_steps: int @classmethod def from_env(cls, server_args: "ServerArgs") -> "CanaryConfig": @@ -73,4 +76,5 @@ def from_env(cls, server_args: "ServerArgs") -> "CanaryConfig": real_kv_hash_mode=RealKvHashMode[real_kv_raw], enable_write_input_assert=envs.SGLANG_KV_CANARY_ENABLE_WRITE_INPUT_ASSERT.get(), enable_verify_token_assert=envs.SGLANG_KV_CANARY_ENABLE_VERIFY_TOKEN_ASSERT.get(), + stats_print_every_n_steps=envs.SGLANG_KV_CANARY_STATS_PRINT_EVERY_N_STEPS.get(), ) diff --git a/python/sglang/srt/kv_canary/runner/canary_manager.py b/python/sglang/srt/kv_canary/runner/canary_manager.py index 5357d567bafa..0b4656cfc97f 100644 --- a/python/sglang/srt/kv_canary/runner/canary_manager.py +++ b/python/sglang/srt/kv_canary/runner/canary_manager.py @@ -18,6 +18,8 @@ ) from sglang.srt.kv_canary.perturb.config import PerturbConfig from sglang.srt.kv_canary.perturb.manager import PerturbManager +from sglang.srt.kv_canary.runner.health_checker import KernelRunCounterHealthChecker +from sglang.srt.kv_canary.runner.stats_logger import PeriodicCanaryStatsLogger from sglang.srt.kv_canary.runner.swa_divergence import SwaDivergenceReporter from sglang.srt.kv_canary.runner.sweep import SweepOrchestrator from sglang.srt.kv_canary.runner.violation_manager import ViolationManager @@ -127,6 +129,22 @@ def __init__( swa_window_size=self._swa_window_size, sweep_interval=config.sweep_interval, ) + self._health_checker = KernelRunCounterHealthChecker( + config=config, + device_state=self._device_state, + active_tags=self._active_tags, + outer_step_counter_getter=self._get_outer_step_counter, + d2h_stream=self._d2h_stream, + ) + self._stats_logger = PeriodicCanaryStatsLogger( + config=config, + device_state=self._device_state, + active_tags=self._active_tags, + outer_step_counter_getter=self._get_outer_step_counter, + sweep_orchestrator=self._sweep_orchestrator, + d2h_stream=self._d2h_stream, + ) + num_sfms = max(1, speculative_num_steps - 1) self._single_forward_managers: tuple[SingleForwardManager, ...] = tuple( SingleForwardManager( @@ -233,6 +251,8 @@ def _post_ops_outside_graph( self._sweep_orchestrator.maybe_run_sweep() self._outer_step_counter += 1 self._violation_manager.step() + self._health_checker.step() + self._stats_logger.step() if self._swa_divergence_report is not None: self._swa_divergence_report.step( outer_step_counter=self._outer_step_counter, diff --git a/python/sglang/srt/kv_canary/runner/health_checker.py b/python/sglang/srt/kv_canary/runner/health_checker.py new file mode 100644 index 000000000000..5fb3ce811b7e --- /dev/null +++ b/python/sglang/srt/kv_canary/runner/health_checker.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Optional + +import torch + +from sglang.jit_kernel.kv_canary.verify import CanaryLaunchTag +from sglang.srt.kv_canary.config import CanaryConfig +from sglang.srt.kv_canary.runner.future_tensor import DelayedDeviceHostHandler +from sglang.srt.kv_canary.runner.kernel_launcher import passes_v_half_gate +from sglang.srt.kv_canary.state import CanaryDeviceState + +logger = logging.getLogger(__name__) + +_HEALTH_CHECK_EVERY_N_STEPS: int = 100 +_HEALTH_CHECK_WARMUP_STEPS: int = 100 +_SWEEP_TAGS: frozenset[CanaryLaunchTag] = frozenset( + ( + CanaryLaunchTag.SWEEP_K_FULL, + CanaryLaunchTag.SWEEP_V_FULL, + CanaryLaunchTag.SWEEP_K_SWA, + CanaryLaunchTag.SWEEP_V_SWA, + ) +) + + +class KernelRunCounterHealthChecker: + def __init__( + self, + *, + config: CanaryConfig, + device_state: CanaryDeviceState, + active_tags: tuple[CanaryLaunchTag, ...], + outer_step_counter_getter: Callable[[], int], + d2h_stream: torch.cuda.Stream, + ) -> None: + self._config = config + self._device_state = device_state + self._active_tags = active_tags + self._outer_step_counter_getter = outer_step_counter_getter + self._handler = DelayedDeviceHostHandler(d2h_stream=d2h_stream) + self._prev_counters_host: torch.Tensor = torch.zeros_like( + device_state.kernel_run_counters, device="cpu" + ) + + def step(self) -> None: + self._handler.step( + compute_on_device=self._compute_on_device, + postprocess_on_host=self._postprocess_on_host, + ) + + def _compute_on_device(self) -> Optional[torch.Tensor]: + outer_step_counter = self._outer_step_counter_getter() + if outer_step_counter < _HEALTH_CHECK_WARMUP_STEPS: + return None + if outer_step_counter % _HEALTH_CHECK_EVERY_N_STEPS != 0: + return None + if not self._active_tags: + return None + return self._device_state.kernel_run_counters + + def _postprocess_on_host(self, new_counter_host: torch.Tensor) -> None: + delta = new_counter_host - self._prev_counters_host + self._prev_counters_host = new_counter_host + expected_tags = self._expected_active_tags_for_health_check() + stalled = [tag for tag in expected_tags if int(delta[tag.value]) == 0] + if stalled: + names = ", ".join(tag.name for tag in stalled) + raise RuntimeError( + f"kv-canary: kernel_run_counter did not increase since previous check " + f"for tags=[{names}] at step={self._outer_step_counter_getter()}; " + f"canary path is not executing" + ) + + def _expected_active_tags_for_health_check(self) -> tuple[CanaryLaunchTag, ...]: + tags = self._active_tags + if self._config.sweep_interval <= 0: + tags = tuple(tag for tag in tags if tag not in _SWEEP_TAGS) + return tuple(tag for tag in tags if passes_v_half_gate(tag)) diff --git a/python/sglang/srt/kv_canary/runner/stats_logger.py b/python/sglang/srt/kv_canary/runner/stats_logger.py new file mode 100644 index 000000000000..e4af7ee4e789 --- /dev/null +++ b/python/sglang/srt/kv_canary/runner/stats_logger.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Any, Optional + +import torch + +from sglang.jit_kernel.kv_canary.verify import CanaryLaunchTag +from sglang.srt.kv_canary.config import CanaryConfig +from sglang.srt.kv_canary.runner.future_tensor import DelayedDeviceHostHandler +from sglang.srt.kv_canary.runner.sweep import SweepOrchestrator +from sglang.srt.kv_canary.state import CanaryDeviceState + +logger = logging.getLogger(__name__) + + +class PeriodicCanaryStatsLogger: + def __init__( + self, + *, + config: CanaryConfig, + device_state: CanaryDeviceState, + active_tags: tuple[CanaryLaunchTag, ...], + outer_step_counter_getter: Callable[[], int], + sweep_orchestrator: SweepOrchestrator, + d2h_stream: torch.cuda.Stream, + ) -> None: + self._config = config + self._device_state = device_state + self._active_tags = active_tags + self._outer_step_counter_getter = outer_step_counter_getter + self._sweep_orchestrator = sweep_orchestrator + self._handler = DelayedDeviceHostHandler(d2h_stream=d2h_stream) + + def step(self) -> None: + self._handler.step( + compute_on_device=self._compute_on_device, + postprocess_on_host=self._postprocess_on_host, + ) + + def _compute_on_device(self) -> Optional[dict[str, Any]]: + period = self._config.stats_print_every_n_steps + if period <= 0: + return None + outer_step_counter = self._outer_step_counter_getter() + if outer_step_counter == 0 or outer_step_counter % period != 0: + return None + device_state = self._device_state + return { + "step": outer_step_counter, + "slot_sum": device_state.slot_run_counters.sum().view(1), + "write_index": device_state.violation_log.violation_write_index, + } + + def _postprocess_on_host(self, host_data: dict[str, Any]) -> None: + logger.info( + "[canary] step=%d protected_tokens=%d sweep_passes=%d violations=%d " + "launch_tags_active=%d/%d", + int(host_data["step"]), + int(host_data["slot_sum"].item()), + self._sweep_orchestrator.sweep_passes, + int(host_data["write_index"].item()), + len(self._active_tags), + len(CanaryLaunchTag), + ) diff --git a/python/sglang/test/kv_canary/fixtures.py b/python/sglang/test/kv_canary/fixtures.py index 6ac7f253a625..6c230c1aa49b 100644 --- a/python/sglang/test/kv_canary/fixtures.py +++ b/python/sglang/test/kv_canary/fixtures.py @@ -140,6 +140,7 @@ def make_base_config() -> CanaryConfig: real_kv_hash_mode=consts.RealKvHashMode.NONE, enable_write_input_assert=False, enable_verify_token_assert=True, + stats_print_every_n_steps=100, ) diff --git a/python/sglang/test/kv_canary/runner_test_base.py b/python/sglang/test/kv_canary/runner_test_base.py index 3ca41411ac44..517bdbeabb17 100644 --- a/python/sglang/test/kv_canary/runner_test_base.py +++ b/python/sglang/test/kv_canary/runner_test_base.py @@ -30,6 +30,7 @@ def make_config( real_kv_hash_mode: RealKvHashMode = RealKvHashMode.NONE, enable_write_input_assert: bool = False, enable_verify_token_assert: bool = True, + stats_print_every_n_steps: int = 100, ) -> CanaryConfig: return CanaryConfig( mode=mode, @@ -38,6 +39,7 @@ def make_config( real_kv_hash_mode=real_kv_hash_mode, enable_write_input_assert=enable_write_input_assert, enable_verify_token_assert=enable_verify_token_assert, + stats_print_every_n_steps=stats_print_every_n_steps, ) diff --git a/test/registered/kv_canary/test_self_unit_buffer_alloc.py b/test/registered/kv_canary/test_self_unit_buffer_alloc.py index 3872b6c7ac29..97d5d84354b0 100644 --- a/test/registered/kv_canary/test_self_unit_buffer_alloc.py +++ b/test/registered/kv_canary/test_self_unit_buffer_alloc.py @@ -26,6 +26,7 @@ def _config(mode: RealKvHashMode) -> CanaryConfig: real_kv_hash_mode=mode, enable_write_input_assert=False, enable_verify_token_assert=False, + stats_print_every_n_steps=100, ) diff --git a/test/registered/kv_canary/test_self_unit_runner_health.py b/test/registered/kv_canary/test_self_unit_runner_health.py new file mode 100644 index 000000000000..7255ba60ce95 --- /dev/null +++ b/test/registered/kv_canary/test_self_unit_runner_health.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import logging +import unittest +from unittest.mock import Mock + +import torch + +from sglang.jit_kernel.kv_canary.verify import CanaryLaunchTag +from sglang.srt.kv_canary.config import CanaryConfig +from sglang.srt.kv_canary.runner import stats_logger as stats_logger_module +from sglang.srt.kv_canary.runner.health_checker import KernelRunCounterHealthChecker +from sglang.srt.kv_canary.state import CanaryDeviceState +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kv_canary.runner_test_base import ( + CanaryManagerTestCase, + make_config, + make_manager, +) +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=45, stage="extra-a", runner_config="1-gpu-small") + + +class TestSelfUnitManagerHealth(CanaryManagerTestCase): + def test_kernel_run_counter_watchdog_raises_on_zero(self) -> None: + """Verify the kernel watchdog raises when counters stop advancing.""" + manager = make_manager(device=self.device) + manager._outer_step_counter = 1000 + manager._device_state.kernel_run_counters.zero_() + manager._health_checker.step() + manager._outer_step_counter = 2000 + with self.assertRaises(RuntimeError): + manager._health_checker.step() + + def test_kernel_run_counter_watchdog_ignores_sweep_when_sweep_is_disabled( + self, + ) -> None: + """Verify the watchdog ignores disabled sweep counters.""" + config = make_config(sweep_interval=0) + manager = make_manager(device=self.device, config=config) + manager._device_state.kernel_run_counters.zero_() + for tag in ( + CanaryLaunchTag.HEAD_K_FULL, + CanaryLaunchTag.HEAD_V_FULL, + CanaryLaunchTag.TAIL_K_FULL, + CanaryLaunchTag.TAIL_V_FULL, + ): + manager._device_state.kernel_run_counters[tag.value] = 1 + + manager._outer_step_counter = 1000 + manager._health_checker.step() + manager._outer_step_counter = 2000 + manager._health_checker.step() + + def test_periodic_stats_log_every_n_step(self) -> None: + """Verify periodic stats are logged at the configured interval.""" + config = make_config(stats_print_every_n_steps=5) + manager = make_manager(device=self.device, config=config) + manager._device_state.slot_run_counters.fill_(7) + + with self.assertLogs(stats_logger_module.logger.name, level=logging.INFO) as cm: + for _ in range(11): + manager._stats_logger.step() + manager._outer_step_counter += 1 + log_text = "\n".join(cm.output) + self.assertIn("protected_tokens=", log_text) + self.assertTrue("step=5" in log_text or "step=10" in log_text) + + +class TestKernelRunCounterDeltaCheck(CustomTestCase): + """Pure host-side regression tests for the watchdog's delta semantics.""" + + def _make_checker( + self, + *, + active_tags: tuple[CanaryLaunchTag, ...], + outer_step: int, + ) -> KernelRunCounterHealthChecker: + config = Mock(spec=CanaryConfig) + config.sweep_interval = 0 + num_tags = len(CanaryLaunchTag) + device_state = Mock(spec=CanaryDeviceState) + device_state.kernel_run_counters = torch.zeros(num_tags, dtype=torch.int64) + return KernelRunCounterHealthChecker( + config=config, + device_state=device_state, + active_tags=active_tags, + outer_step_counter_getter=lambda: outer_step, + d2h_stream=Mock(), + ) + + def _host_tensor(self, value: int) -> torch.Tensor: + return torch.full((len(CanaryLaunchTag),), value, dtype=torch.int64) + + def test_first_check_raises_when_counter_never_incremented(self) -> None: + active_tags = (CanaryLaunchTag.HEAD_K_FULL, CanaryLaunchTag.TAIL_K_FULL) + checker = self._make_checker(active_tags=active_tags, outer_step=100) + + host_counters = torch.zeros(len(CanaryLaunchTag), dtype=torch.int64) + with self.assertRaises(RuntimeError) as cm: + checker._postprocess_on_host(host_counters) + message = str(cm.exception) + self.assertIn(CanaryLaunchTag.HEAD_K_FULL.name, message) + self.assertIn(CanaryLaunchTag.TAIL_K_FULL.name, message) + self.assertIn("did not increase", message) + + def test_second_check_raises_when_delta_is_zero(self) -> None: + active_tags = (CanaryLaunchTag.HEAD_K_FULL,) + checker = self._make_checker(active_tags=active_tags, outer_step=100) + + checker._postprocess_on_host(self._host_tensor(1)) + with self.assertRaises(RuntimeError) as cm: + checker._postprocess_on_host(self._host_tensor(1)) + message = str(cm.exception) + self.assertIn(CanaryLaunchTag.HEAD_K_FULL.name, message) + self.assertIn("did not increase", message) + + def test_no_raise_when_counter_increases(self) -> None: + active_tags = (CanaryLaunchTag.HEAD_K_FULL, CanaryLaunchTag.TAIL_K_FULL) + checker = self._make_checker(active_tags=active_tags, outer_step=100) + + for value in (1, 5, 12, 100): + checker._postprocess_on_host(self._host_tensor(value)) + + +if __name__ == "__main__": + unittest.main()