Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ class Envs:
# KV-Canary / Token-Oracle (testing-only)
# ===================================================================
SGLANG_KV_CANARY_RING_CAPACITY = EnvInt(1024)
SGLANG_KV_CANARY_STATS_PRINT_EVERY_N_STEPS = EnvInt(100)
SGLANG_KV_CANARY_ENABLE_WRITE_INPUT_ASSERT = EnvBool(False)
SGLANG_KV_CANARY_PERTURB_REQ_TO_TOKEN_PROB = EnvFloat(0.0)
SGLANG_KV_CANARY_PERTURB_WARMUP_STEPS = EnvInt(50)
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/kv_canary/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class CanaryConfig:
expected_tokens from each req's ``origin_input_ids + output_ids`` (snapshotted at
ForwardBatch.init_new) and compare against the canary's stored tokens at verify time.
Independent of ``enable_write_input_assert``.
stats_print_every_n_steps: 0 disables periodic stats logging; positive N prints
"canary protected N tokens, ran M sweep passes, K violations so far" every N forward steps.
"""

mode: CanaryMode
Expand All @@ -55,6 +57,7 @@ class CanaryConfig:
real_kv_hash_mode: RealKvHashMode
enable_write_input_assert: bool
enable_verify_token_assert: bool
stats_print_every_n_steps: int

@classmethod
def from_env(cls, server_args: "ServerArgs") -> "CanaryConfig":
Expand All @@ -73,4 +76,5 @@ def from_env(cls, server_args: "ServerArgs") -> "CanaryConfig":
real_kv_hash_mode=RealKvHashMode[real_kv_raw],
enable_write_input_assert=envs.SGLANG_KV_CANARY_ENABLE_WRITE_INPUT_ASSERT.get(),
enable_verify_token_assert=envs.SGLANG_KV_CANARY_ENABLE_VERIFY_TOKEN_ASSERT.get(),
stats_print_every_n_steps=envs.SGLANG_KV_CANARY_STATS_PRINT_EVERY_N_STEPS.get(),
)
20 changes: 20 additions & 0 deletions python/sglang/srt/kv_canary/runner/canary_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)
from sglang.srt.kv_canary.perturb.config import PerturbConfig
from sglang.srt.kv_canary.perturb.manager import PerturbManager
from sglang.srt.kv_canary.runner.health_checker import KernelRunCounterHealthChecker
from sglang.srt.kv_canary.runner.stats_logger import PeriodicCanaryStatsLogger
from sglang.srt.kv_canary.runner.swa_divergence import SwaDivergenceReporter
from sglang.srt.kv_canary.runner.sweep import SweepOrchestrator
from sglang.srt.kv_canary.runner.violation_manager import ViolationManager
Expand Down Expand Up @@ -127,6 +129,22 @@ def __init__(
swa_window_size=self._swa_window_size,
sweep_interval=config.sweep_interval,
)
self._health_checker = KernelRunCounterHealthChecker(
config=config,
device_state=self._device_state,
active_tags=self._active_tags,
outer_step_counter_getter=self._get_outer_step_counter,
d2h_stream=self._d2h_stream,
)
self._stats_logger = PeriodicCanaryStatsLogger(
config=config,
device_state=self._device_state,
active_tags=self._active_tags,
outer_step_counter_getter=self._get_outer_step_counter,
sweep_orchestrator=self._sweep_orchestrator,
d2h_stream=self._d2h_stream,
)

num_sfms = max(1, speculative_num_steps - 1)
self._single_forward_managers: tuple[SingleForwardManager, ...] = tuple(
SingleForwardManager(
Expand Down Expand Up @@ -233,6 +251,8 @@ def _post_ops_outside_graph(
self._sweep_orchestrator.maybe_run_sweep()
self._outer_step_counter += 1
self._violation_manager.step()
self._health_checker.step()
self._stats_logger.step()
if self._swa_divergence_report is not None:
self._swa_divergence_report.step(
outer_step_counter=self._outer_step_counter,
Expand Down
81 changes: 81 additions & 0 deletions python/sglang/srt/kv_canary/runner/health_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Optional

import torch

from sglang.jit_kernel.kv_canary.verify import CanaryLaunchTag
from sglang.srt.kv_canary.config import CanaryConfig
from sglang.srt.kv_canary.runner.future_tensor import DelayedDeviceHostHandler
from sglang.srt.kv_canary.runner.kernel_launcher import passes_v_half_gate
from sglang.srt.kv_canary.state import CanaryDeviceState

logger = logging.getLogger(__name__)

_HEALTH_CHECK_EVERY_N_STEPS: int = 100
_HEALTH_CHECK_WARMUP_STEPS: int = 100
_SWEEP_TAGS: frozenset[CanaryLaunchTag] = frozenset(
(
CanaryLaunchTag.SWEEP_K_FULL,
CanaryLaunchTag.SWEEP_V_FULL,
CanaryLaunchTag.SWEEP_K_SWA,
CanaryLaunchTag.SWEEP_V_SWA,
)
)


class KernelRunCounterHealthChecker:
def __init__(
self,
*,
config: CanaryConfig,
device_state: CanaryDeviceState,
active_tags: tuple[CanaryLaunchTag, ...],
outer_step_counter_getter: Callable[[], int],
d2h_stream: torch.cuda.Stream,
) -> None:
self._config = config
self._device_state = device_state
self._active_tags = active_tags
self._outer_step_counter_getter = outer_step_counter_getter
self._handler = DelayedDeviceHostHandler(d2h_stream=d2h_stream)
self._prev_counters_host: torch.Tensor = torch.zeros_like(
device_state.kernel_run_counters, device="cpu"
)

def step(self) -> None:
self._handler.step(
compute_on_device=self._compute_on_device,
postprocess_on_host=self._postprocess_on_host,
)

def _compute_on_device(self) -> Optional[torch.Tensor]:
outer_step_counter = self._outer_step_counter_getter()
if outer_step_counter < _HEALTH_CHECK_WARMUP_STEPS:
return None
if outer_step_counter % _HEALTH_CHECK_EVERY_N_STEPS != 0:
return None
if not self._active_tags:
return None
return self._device_state.kernel_run_counters

def _postprocess_on_host(self, new_counter_host: torch.Tensor) -> None:
delta = new_counter_host - self._prev_counters_host
self._prev_counters_host = new_counter_host
expected_tags = self._expected_active_tags_for_health_check()
stalled = [tag for tag in expected_tags if int(delta[tag.value]) == 0]
if stalled:
names = ", ".join(tag.name for tag in stalled)
raise RuntimeError(
f"kv-canary: kernel_run_counter did not increase since previous check "
f"for tags=[{names}] at step={self._outer_step_counter_getter()}; "
f"canary path is not executing"
)

def _expected_active_tags_for_health_check(self) -> tuple[CanaryLaunchTag, ...]:
tags = self._active_tags
if self._config.sweep_interval <= 0:
tags = tuple(tag for tag in tags if tag not in _SWEEP_TAGS)
return tuple(tag for tag in tags if passes_v_half_gate(tag))
66 changes: 66 additions & 0 deletions python/sglang/srt/kv_canary/runner/stats_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any, Optional

import torch

from sglang.jit_kernel.kv_canary.verify import CanaryLaunchTag
from sglang.srt.kv_canary.config import CanaryConfig
from sglang.srt.kv_canary.runner.future_tensor import DelayedDeviceHostHandler
from sglang.srt.kv_canary.runner.sweep import SweepOrchestrator
from sglang.srt.kv_canary.state import CanaryDeviceState

logger = logging.getLogger(__name__)


class PeriodicCanaryStatsLogger:
def __init__(
self,
*,
config: CanaryConfig,
device_state: CanaryDeviceState,
active_tags: tuple[CanaryLaunchTag, ...],
outer_step_counter_getter: Callable[[], int],
sweep_orchestrator: SweepOrchestrator,
d2h_stream: torch.cuda.Stream,
) -> None:
self._config = config
self._device_state = device_state
self._active_tags = active_tags
self._outer_step_counter_getter = outer_step_counter_getter
self._sweep_orchestrator = sweep_orchestrator
self._handler = DelayedDeviceHostHandler(d2h_stream=d2h_stream)

def step(self) -> None:
self._handler.step(
compute_on_device=self._compute_on_device,
postprocess_on_host=self._postprocess_on_host,
)

def _compute_on_device(self) -> Optional[dict[str, Any]]:
period = self._config.stats_print_every_n_steps
if period <= 0:
return None
outer_step_counter = self._outer_step_counter_getter()
if outer_step_counter == 0 or outer_step_counter % period != 0:
return None
device_state = self._device_state
return {
"step": outer_step_counter,
"slot_sum": device_state.slot_run_counters.sum().view(1),
"write_index": device_state.violation_log.violation_write_index,
}

def _postprocess_on_host(self, host_data: dict[str, Any]) -> None:
logger.info(
"[canary] step=%d protected_tokens=%d sweep_passes=%d violations=%d "
"launch_tags_active=%d/%d",
int(host_data["step"]),
int(host_data["slot_sum"].item()),
self._sweep_orchestrator.sweep_passes,
int(host_data["write_index"].item()),
len(self._active_tags),
len(CanaryLaunchTag),
)
1 change: 1 addition & 0 deletions python/sglang/test/kv_canary/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def make_base_config() -> CanaryConfig:
real_kv_hash_mode=consts.RealKvHashMode.NONE,
enable_write_input_assert=False,
enable_verify_token_assert=True,
stats_print_every_n_steps=100,
)


Expand Down
2 changes: 2 additions & 0 deletions python/sglang/test/kv_canary/runner_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def make_config(
real_kv_hash_mode: RealKvHashMode = RealKvHashMode.NONE,
enable_write_input_assert: bool = False,
enable_verify_token_assert: bool = True,
stats_print_every_n_steps: int = 100,
) -> CanaryConfig:
return CanaryConfig(
mode=mode,
Expand All @@ -38,6 +39,7 @@ def make_config(
real_kv_hash_mode=real_kv_hash_mode,
enable_write_input_assert=enable_write_input_assert,
enable_verify_token_assert=enable_verify_token_assert,
stats_print_every_n_steps=stats_print_every_n_steps,
)


Expand Down
1 change: 1 addition & 0 deletions test/registered/kv_canary/test_self_unit_buffer_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _config(mode: RealKvHashMode) -> CanaryConfig:
real_kv_hash_mode=mode,
enable_write_input_assert=False,
enable_verify_token_assert=False,
stats_print_every_n_steps=100,
)


Expand Down
128 changes: 128 additions & 0 deletions test/registered/kv_canary/test_self_unit_runner_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import logging
import unittest
from unittest.mock import Mock

import torch

from sglang.jit_kernel.kv_canary.verify import CanaryLaunchTag
from sglang.srt.kv_canary.config import CanaryConfig
from sglang.srt.kv_canary.runner import stats_logger as stats_logger_module
from sglang.srt.kv_canary.runner.health_checker import KernelRunCounterHealthChecker
from sglang.srt.kv_canary.state import CanaryDeviceState
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.kv_canary.runner_test_base import (
CanaryManagerTestCase,
make_config,
make_manager,
)
from sglang.test.test_utils import CustomTestCase

register_cuda_ci(est_time=45, stage="extra-a", runner_config="1-gpu-small")


class TestSelfUnitManagerHealth(CanaryManagerTestCase):
def test_kernel_run_counter_watchdog_raises_on_zero(self) -> None:
"""Verify the kernel watchdog raises when counters stop advancing."""
manager = make_manager(device=self.device)
manager._outer_step_counter = 1000
manager._device_state.kernel_run_counters.zero_()
manager._health_checker.step()
manager._outer_step_counter = 2000
with self.assertRaises(RuntimeError):
manager._health_checker.step()

def test_kernel_run_counter_watchdog_ignores_sweep_when_sweep_is_disabled(
self,
) -> None:
"""Verify the watchdog ignores disabled sweep counters."""
config = make_config(sweep_interval=0)
manager = make_manager(device=self.device, config=config)
manager._device_state.kernel_run_counters.zero_()
for tag in (
CanaryLaunchTag.HEAD_K_FULL,
CanaryLaunchTag.HEAD_V_FULL,
CanaryLaunchTag.TAIL_K_FULL,
CanaryLaunchTag.TAIL_V_FULL,
):
manager._device_state.kernel_run_counters[tag.value] = 1

manager._outer_step_counter = 1000
manager._health_checker.step()
manager._outer_step_counter = 2000
manager._health_checker.step()

def test_periodic_stats_log_every_n_step(self) -> None:
"""Verify periodic stats are logged at the configured interval."""
config = make_config(stats_print_every_n_steps=5)
manager = make_manager(device=self.device, config=config)
manager._device_state.slot_run_counters.fill_(7)

with self.assertLogs(stats_logger_module.logger.name, level=logging.INFO) as cm:
for _ in range(11):
manager._stats_logger.step()
manager._outer_step_counter += 1
log_text = "\n".join(cm.output)
self.assertIn("protected_tokens=", log_text)
self.assertTrue("step=5" in log_text or "step=10" in log_text)


class TestKernelRunCounterDeltaCheck(CustomTestCase):
"""Pure host-side regression tests for the watchdog's delta semantics."""

def _make_checker(
self,
*,
active_tags: tuple[CanaryLaunchTag, ...],
outer_step: int,
) -> KernelRunCounterHealthChecker:
config = Mock(spec=CanaryConfig)
config.sweep_interval = 0
num_tags = len(CanaryLaunchTag)
device_state = Mock(spec=CanaryDeviceState)
device_state.kernel_run_counters = torch.zeros(num_tags, dtype=torch.int64)
return KernelRunCounterHealthChecker(
config=config,
device_state=device_state,
active_tags=active_tags,
outer_step_counter_getter=lambda: outer_step,
d2h_stream=Mock(),
)

def _host_tensor(self, value: int) -> torch.Tensor:
return torch.full((len(CanaryLaunchTag),), value, dtype=torch.int64)

def test_first_check_raises_when_counter_never_incremented(self) -> None:
active_tags = (CanaryLaunchTag.HEAD_K_FULL, CanaryLaunchTag.TAIL_K_FULL)
checker = self._make_checker(active_tags=active_tags, outer_step=100)

host_counters = torch.zeros(len(CanaryLaunchTag), dtype=torch.int64)
with self.assertRaises(RuntimeError) as cm:
checker._postprocess_on_host(host_counters)
message = str(cm.exception)
self.assertIn(CanaryLaunchTag.HEAD_K_FULL.name, message)
self.assertIn(CanaryLaunchTag.TAIL_K_FULL.name, message)
self.assertIn("did not increase", message)

def test_second_check_raises_when_delta_is_zero(self) -> None:
active_tags = (CanaryLaunchTag.HEAD_K_FULL,)
checker = self._make_checker(active_tags=active_tags, outer_step=100)

checker._postprocess_on_host(self._host_tensor(1))
with self.assertRaises(RuntimeError) as cm:
checker._postprocess_on_host(self._host_tensor(1))
message = str(cm.exception)
self.assertIn(CanaryLaunchTag.HEAD_K_FULL.name, message)
self.assertIn("did not increase", message)

def test_no_raise_when_counter_increases(self) -> None:
active_tags = (CanaryLaunchTag.HEAD_K_FULL, CanaryLaunchTag.TAIL_K_FULL)
checker = self._make_checker(active_tags=active_tags, outer_step=100)

for value in (1, 5, 12, 100):
checker._postprocess_on_host(self._host_tensor(value))


if __name__ == "__main__":
unittest.main()
Loading