diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index bcf8a28a46f3..f347b95ed060 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -1295,11 +1295,13 @@ async def resume_memory_occupation(
 @app.post("/weights_checker")
 @auth_level(AuthLevel.ADMIN_OPTIONAL)
 async def check_weights(obj: CheckWeightsReqInput, request: Request):
-    success, message = await _global_state.tokenizer_manager.check_weights(obj, request)
-    return ORJSONResponse(
-        {"success": success, "message": message},
-        status_code=200 if success else HTTPStatus.BAD_REQUEST,
+    success, message, ranks = await _global_state.tokenizer_manager.check_weights(
+        obj, request
     )
+    body = {"success": success, "message": message}
+    if ranks is not None:
+        body["ranks"] = ranks
+    return ORJSONResponse(body, status_code=200 if success else HTTPStatus.BAD_REQUEST)
 
 
 @app.api_route("/slow_down", methods=["GET", "POST"])
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index a573e672a8ae..71391bd6d3da 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1625,6 +1625,7 @@ class CheckWeightsReqInput(BaseReq):
 class CheckWeightsReqOutput(BaseReq):
     success: bool
     message: str
+    payload: Optional[Dict] = None
 
 
 @dataclass
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index bfb6f084a7b2..590537fd6bb6 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -196,8 +196,10 @@ def resume_memory_occupation(
 
     def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
         try:
-            self.tp_worker.model_runner.check_weights(action=recv_req.action)
-            return CheckWeightsReqOutput(success=True, message="Success.")
+            payload = self.tp_worker.model_runner.check_weights(action=recv_req.action)
+            return CheckWeightsReqOutput(
+                success=True, message="Success.", payload=payload
+            )
         except Exception as e:
             logger.warning(f"check_weights see error: {e}")
             traceback.print_exc()
diff --git a/python/sglang/srt/managers/tokenizer_control_mixin.py b/python/sglang/srt/managers/tokenizer_control_mixin.py
index c99999f4b38e..05382e073eda 100644
--- a/python/sglang/srt/managers/tokenizer_control_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_control_mixin.py
@@ -764,10 +764,14 @@ async def check_weights(
         self: TokenizerManager,
         obj: CheckWeightsReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> CheckWeightsReqOutput:
+    ) -> Tuple[bool, str, Optional[List[Dict]]]:
         self.auto_create_handle_loop()
         results = await self.check_weights_communicator(obj)
-        return FanOutCommunicator.merge_results(results)
+        success, message = FanOutCommunicator.merge_results(results)
+        ranks: Optional[List[Dict]] = None
+        if any(r.payload is not None for r in results):
+            ranks = [r.payload for r in results]
+        return success, message, ranks
 
     async def slow_down(
         self: TokenizerManager,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7bdd8ffd686c..97ef0fdf2cad 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -3484,7 +3484,7 @@ def save_sharded_model(
         ShardedStateLoader.save_model(self.model, path, pattern, max_size)
 
     def check_weights(self, action: str):
-        self._weight_checker.handle(action=action)
+        return self._weight_checker.handle(action=action)
 
     def update_weights_from_ipc(self, recv_req):
         """Update weights from IPC for checkpoint-engine integration."""
diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py
index fa5d97ebdb62..6fd74a4e05f2 100644
--- a/python/sglang/srt/utils/weight_checker.py
+++ b/python/sglang/srt/utils/weight_checker.py
@@ -1,16 +1,40 @@
 import logging
-from typing import Dict, Iterable, Tuple
+import time
+from typing import Dict, Iterable, Optional, Tuple
 
 import torch
+import torch.distributed as dist
+from pydantic import BaseModel, ConfigDict
 
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_dequant,
     inverse_transform_scale_ue8m0,
 )
+from sglang.srt.managers.mm_utils import tensor_hash
 
 logger = logging.getLogger(__name__)
 
+class _StrictBaseModel(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+
+class ParallelismInfo(_StrictBaseModel):
+    tp_rank: int
+    tp_size: int
+    dp_rank: int
+    dp_size: int
+    pp_rank: int
+    pp_size: int
+    rank: int
+    size: int
+
+
+class ChecksumInfo(_StrictBaseModel):
+    checksums: Dict[str, str]
+    parallelism_info: ParallelismInfo
+
+
 _NON_PERSISTENT_BUFFER_PATTERNS = (
     "cos_sin_cache",
     "inv_freq",
 )
@@ -28,14 +52,16 @@ def __init__(self, model_runner):
         self._model_runner = model_runner
         self._snapshot_tensors = None
 
-    def handle(self, action: str):
+    def handle(self, action: str) -> Optional[Dict]:
         logger.info(f"[WeightChecker] handle action={action}")
         if action == "snapshot":
-            self._snapshot()
+            return self._snapshot()
         elif action == "reset_tensors":
-            self._reset_tensors()
+            return self._reset_tensors()
         elif action == "compare":
-            self._compare()
+            return self._compare()
+        elif action == "checksum":
+            return self._compute_checksum()
         else:
             raise Exception(f"Unsupported {action=}")
 
@@ -62,12 +88,56 @@ def _compare(self):
             actual_tensors=_postprocess_tensors(dict(self._model_state())),
         )
 
+    def _compute_checksum(self) -> Dict:
+        torch.cuda.synchronize()
+        start = time.perf_counter()
+
+        # Reuse the snapshot/compare postprocess pipeline so fp8 weights are
+        # dequantized to bf16 before hashing — two (qweight, scale) pairs that
+        # produce the same bf16 must produce the same checksum.
+        checksums = {
+            name: _hash_tensor(tensor.data)
+            for name, should_compare, tensor in _postprocess_tensors(
+                dict(self._model_state())
+            )
+            if should_compare
+        }
+
+        torch.cuda.synchronize()
+        elapsed = time.perf_counter() - start
+        logger.info(
+            f"[WeightChecker] checksum computed for {len(checksums)} tensors in {elapsed:.3f}s"
+        )
+
+        info = ChecksumInfo(
+            checksums=checksums,
+            parallelism_info=self._parallelism_info(),
+        )
+        return info.model_dump()
+
+    def _parallelism_info(self) -> ParallelismInfo:
+        mr = self._model_runner
+        return ParallelismInfo(
+            tp_rank=mr.tp_rank,
+            tp_size=mr.tp_size,
+            dp_rank=mr.dp_rank if mr.dp_rank is not None else 0,
+            dp_size=mr.dp_size,
+            pp_rank=mr.pp_rank,
+            pp_size=mr.pp_size,
+            rank=dist.get_rank() if dist.is_initialized() else 0,
+            size=dist.get_world_size() if dist.is_initialized() else 1,
+        )
+
     def _model_state(self):
        # TODO: support EAGLE etc (e.g. yield from both main model and draft model)
         yield from self._model_runner.model.named_parameters()
         yield from self._model_runner.model.named_buffers()
 
 
+def _hash_tensor(t: torch.Tensor) -> str:
+    return f"{tensor_hash(t):016x}"
+
+
 def _check_tensors(
     expect_tensors: Iterable[Tuple[str, bool, torch.Tensor]],
     actual_tensors: Iterable[Tuple[str, bool, torch.Tensor]],
diff --git a/test/registered/rl/test_weight_checker_e2e.py b/test/registered/rl/test_weight_checker_e2e.py
index 85df5a606fef..0ac00fcb3512 100644
--- a/test/registered/rl/test_weight_checker_e2e.py
+++ b/test/registered/rl/test_weight_checker_e2e.py
@@ -138,6 +138,74 @@ def test_d_update_with_same_tensor_keeps_compare_passing(self):
         self.assertEqual(resp.status_code, 200)
         self.assertTrue(resp.json()["success"])
 
+    def test_e_checksum_returns_ranks_with_hashes(self):
+        """checksum action must yield a ranks list with hex hashes per rank."""
+        resp = self._post("checksum")
+        self.assertEqual(resp.status_code, 200)
+        body = resp.json()
+        self.assertTrue(body["success"])
+        self.assertIn("ranks", body)
+        ranks = body["ranks"]
+        self.assertIsInstance(ranks, list)
+        self.assertGreaterEqual(len(ranks), 1)
+
+        first = ranks[0]
+        self.assertIn("checksums", first)
+        self.assertIn("parallelism_info", first)
+
+        info = first["parallelism_info"]
+        for key in (
+            "tp_rank",
+            "tp_size",
+            "dp_rank",
+            "dp_size",
+            "pp_rank",
+            "pp_size",
+            "rank",
+            "size",
+        ):
+            self.assertIn(key, info)
+
+        checksums = first["checksums"]
+        self.assertGreater(len(checksums), 0)
+        for name, h in checksums.items():
+            self.assertIsInstance(h, str)
+            self.assertEqual(len(h), 16, f"unexpected hash length for {name!r}: {h!r}")
+            int(h, 16)
+
+    def test_e_checksum_is_stable_across_calls(self):
+        """Two consecutive checksum calls with no weight update must match."""
+        first = self._post("checksum").json()["ranks"]
+        second = self._post("checksum").json()["ranks"]
+        self.assertEqual(first, second)
+
+    def test_e_checksum_changes_after_weight_update(self):
+        """Updating a tensor must change its corresponding hash."""
+        param_name = "model.layers.7.mlp.up_proj.weight"
+        fused_name = "model.layers.7.mlp.gate_up_proj.weight"
+
+        before = self._post("checksum").json()["ranks"][0]["checksums"]
+        before_hash = before.get(fused_name)
+        self.assertIsNotNone(before_hash, f"missing {fused_name!r} in checksum keys")
+
+        new_tensor = torch.full(_UP_PROJ_SHAPE, 0.5, device="cuda")
+        self.assertTrue(
+            self._update_weights([(param_name, new_tensor)]).json()["success"]
+        )
+
+        after = self._post("checksum").json()["ranks"][0]["checksums"]
+        self.assertNotEqual(after[fused_name], before_hash)
+
+    def test_e_checksum_skips_non_persistent_buffers(self):
+        """No checksum entry should contain a non-persistent-buffer substring."""
+        ranks = self._post("checksum").json()["ranks"]
+        for rank in ranks:
+            for name in rank["checksums"]:
+                self.assertNotIn("cos_sin_cache", name)
+                self.assertNotIn("inv_freq", name)
+                self.assertNotIn("freqs_cis", name)
+                self.assertNotIn("_weight_fp32", name)
+
     def test_z_snapshot_reset_compare_detects_diff(self):
        """Destructive: leaves weights randomized. Named test_z_* so it runs last."""
         self.assertEqual(self._post("snapshot").status_code, 200)
diff --git a/test/registered/unit/utils/test_weight_checker.py b/test/registered/unit/utils/test_weight_checker.py
index ca7b75102622..17ec362719df 100644
--- a/test/registered/unit/utils/test_weight_checker.py
+++ b/test/registered/unit/utils/test_weight_checker.py
@@ -26,8 +26,12 @@
     transform_scale_ue8m0,
 )
 from sglang.srt.utils.weight_checker import (
+    ChecksumInfo,
+    ParallelismInfo,
     WeightChecker,
     _check_tensors,
+    _hash_tensor,
+    _is_non_persistent_buffer_name,
     _postprocess_tensors,
     _random_like,
 )
@@ -99,11 +103,26 @@ def __init__(self):
 
 
 class _FakeModelRunner:
-    """Minimal stand-in: WeightChecker only touches `.model.named_parameters()` and
-    `.model.named_buffers()`, nothing else."""
-
-    def __init__(self, model: nn.Module):
+    """Minimal stand-in: WeightChecker touches `.model.named_parameters()`,
+    `.model.named_buffers()`, plus parallelism attributes for the checksum action."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        tp_rank: int = 0,
+        tp_size: int = 1,
+        dp_rank: int = 0,
+        dp_size: int = 1,
+        pp_rank: int = 0,
+        pp_size: int = 1,
+    ):
         self.model = model
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.dp_rank = dp_rank
+        self.dp_size = dp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
 
 
 # ---------------------------------------------------------------------------
@@ -449,13 +468,27 @@
     def test_routes_to_actions(self):
         with patch.object(self.checker, "_snapshot") as m_snap, patch.object(
             self.checker, "_reset_tensors"
-        ) as m_reset, patch.object(self.checker, "_compare") as m_compare:
+        ) as m_reset, patch.object(self.checker, "_compare") as m_compare, patch.object(
+            self.checker, "_compute_checksum", return_value={"checksums": {}}
+        ) as m_checksum:
             self.checker.handle("snapshot")
             self.checker.handle("reset_tensors")
             self.checker.handle("compare")
+            self.checker.handle("checksum")
             m_snap.assert_called_once()
             m_reset.assert_called_once()
             m_compare.assert_called_once()
+            m_checksum.assert_called_once()
+
+    def test_returns_none_for_non_checksum_actions(self):
+        self.assertIsNone(self.checker.handle("snapshot"))
+        self.assertIsNone(self.checker.handle("compare"))
+
+    def test_returns_dict_for_checksum_action(self):
+        out = self.checker.handle("checksum")
+        self.assertIsInstance(out, dict)
+        self.assertIn("checksums", out)
+        self.assertIn("parallelism_info", out)
 
     def test_unknown_action_raises(self):
         with self.assertRaises(Exception) as ctx:
@@ -463,5 +496,138 @@
         self.assertIn("Unsupported", str(ctx.exception))
 
 
+# ---------------------------------------------------------------------------
+# _is_non_persistent_buffer_name
+# ---------------------------------------------------------------------------
+
+
+class TestIsNonPersistentBufferName(CustomTestCase):
+
+    def test_matches_cos_sin_cache_substring(self):
+        self.assertTrue(
+            _is_non_persistent_buffer_name("model.rotary_emb.cos_sin_cache")
+        )
+
+    def test_matches_inv_freq_substring(self):
+        self.assertTrue(_is_non_persistent_buffer_name("model.rotary_emb.inv_freq"))
+
+    def test_matches_freqs_cis_substring(self):
+        self.assertTrue(_is_non_persistent_buffer_name("model.rotary_emb.freqs_cis"))
+
+    def test_matches_weight_fp32_substring(self):
+        self.assertTrue(
+            _is_non_persistent_buffer_name("model.layers.0.mlp.gate._weight_fp32")
+        )
+
+    def test_does_not_match_normal_param_names(self):
+        self.assertFalse(_is_non_persistent_buffer_name("model.layers.0.mlp.weight"))
+        self.assertFalse(_is_non_persistent_buffer_name("model.embed_tokens.weight"))
+
+
+# ---------------------------------------------------------------------------
+# _hash_tensor
+# ---------------------------------------------------------------------------
+
+
+class TestHashTensor(CustomTestCase):
+
+    def test_stable_for_same_input(self):
+        t = torch.arange(64, dtype=torch.float32).cuda()
+        self.assertEqual(_hash_tensor(t), _hash_tensor(t.clone()))
+
+    def test_changes_with_data(self):
+        a = torch.zeros(64, dtype=torch.float32).cuda()
+        b = torch.ones(64, dtype=torch.float32).cuda()
+        self.assertNotEqual(_hash_tensor(a), _hash_tensor(b))
+
+    def test_returns_16_char_hex(self):
+        t = torch.zeros(64, dtype=torch.float32).cuda()
+        h = _hash_tensor(t)
+        self.assertEqual(len(h), 16)
+        int(h, 16)  # raises if not hex
+
+    def test_does_not_mutate_input(self):
+        t = torch.arange(64, dtype=torch.float32).cuda()
+        before = t.clone()
+        _hash_tensor(t)
+        torch.testing.assert_close(t, before)
+
+
+# ---------------------------------------------------------------------------
+# _compute_checksum
+# ---------------------------------------------------------------------------
+
+
+class _ChecksumTestBase(CustomTestCase):
+
+    def setUp(self):
+        torch.manual_seed(0)
+        self.model = _TinyModel().cuda()
+        self.runner = _FakeModelRunner(
+            self.model,
+            tp_rank=2,
+            tp_size=4,
+            dp_rank=1,
+            dp_size=2,
+            pp_rank=0,
+            pp_size=1,
+        )
+        self.checker = WeightChecker(model_runner=self.runner)
+
+
+class TestComputeChecksum(_ChecksumTestBase):
+
+    def test_returns_dict_with_expected_top_level_keys(self):
+        out = self.checker._compute_checksum()
+        self.assertEqual(set(out.keys()), {"checksums", "parallelism_info"})
+
+    def test_skips_non_persistent_buffers(self):
+        out = self.checker._compute_checksum()
+        names = set(out["checksums"].keys())
+        # Normal params and buffers are present.
+        self.assertIn("w", names)
+        self.assertIn("b", names)
+        self.assertIn("running_mean", names)
+        # Non-persistent buffer patterns are filtered out.
+        self.assertNotIn("rotary_emb_cos_sin_cache", names)
+        self.assertNotIn("rotary_emb_freqs_cis", names)
+        self.assertNotIn("gate_proj_weight_fp32_cache", names)
+
+    def test_hashes_are_hex_strings(self):
+        out = self.checker._compute_checksum()
+        for name, h in out["checksums"].items():
+            self.assertEqual(len(h), 16, f"unexpected hash length for {name!r}")
+            int(h, 16)
+
+    def test_parallelism_info_reflects_runner_state(self):
+        info = self.checker._compute_checksum()["parallelism_info"]
+        self.assertEqual(info["tp_rank"], 2)
+        self.assertEqual(info["tp_size"], 4)
+        self.assertEqual(info["dp_rank"], 1)
+        self.assertEqual(info["dp_size"], 2)
+        self.assertEqual(info["pp_rank"], 0)
+        self.assertEqual(info["pp_size"], 1)
+        # rank/size come from torch.distributed; default to 0/1 when uninitialized.
+ self.assertIn("rank", info) + self.assertIn("size", info) + + def test_checksum_is_stable_for_unchanged_weights(self): + first = self.checker._compute_checksum() + second = self.checker._compute_checksum() + self.assertEqual(first, second) + + def test_checksum_changes_after_param_mutation(self): + first = self.checker._compute_checksum()["checksums"]["w"] + with torch.no_grad(): + self.model.w.data.fill_(99.0) + second = self.checker._compute_checksum()["checksums"]["w"] + self.assertNotEqual(first, second) + + def test_validates_against_pydantic_schema(self): + out = self.checker._compute_checksum() + info = ChecksumInfo.model_validate(out) + self.assertIsInstance(info.parallelism_info, ParallelismInfo) + + if __name__ == "__main__": unittest.main()