Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,11 +1295,13 @@ async def resume_memory_occupation(
@app.post("/weights_checker")
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def check_weights(obj: CheckWeightsReqInput, request: Request):
    """Run a weights-checker action and report the aggregate outcome.

    The tokenizer manager fans the request out to all ranks and returns
    ``(success, message, ranks)``; ``ranks`` is ``None`` unless at least one
    rank produced a payload (e.g. for the "checksum" action).  Responds 200
    on success, 400 otherwise.
    """
    success, message, ranks = await _global_state.tokenizer_manager.check_weights(
        obj, request
    )
    body = {"success": success, "message": message}
    # Only attach "ranks" when there is per-rank data, so existing clients of
    # the plain {success, message} shape are unaffected.
    if ranks is not None:
        body["ranks"] = ranks
    return ORJSONResponse(body, status_code=200 if success else HTTPStatus.BAD_REQUEST)


@app.api_route("/slow_down", methods=["GET", "POST"])
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,7 @@ class CheckWeightsReqInput(BaseReq):
class CheckWeightsReqOutput(BaseReq):
    """Per-rank result of a weights-checker action."""

    # True when the requested action completed without raising.
    success: bool
    # Human-readable status; "Success." when the action succeeded.
    message: str
    # Optional data produced by the action (e.g. the checksum report for
    # action == "checksum"); None when the action returns nothing.
    payload: Optional[Dict] = None


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,10 @@ def resume_memory_occupation(

def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
try:
self.tp_worker.model_runner.check_weights(action=recv_req.action)
return CheckWeightsReqOutput(success=True, message="Success.")
payload = self.tp_worker.model_runner.check_weights(action=recv_req.action)
return CheckWeightsReqOutput(
success=True, message="Success.", payload=payload
)
except Exception as e:
logger.warning(f"check_weights see error: {e}")
traceback.print_exc()
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/managers/tokenizer_control_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,10 +764,14 @@ async def check_weights(
self: TokenizerManager,
obj: CheckWeightsReqInput,
request: Optional[fastapi.Request] = None,
) -> CheckWeightsReqOutput:
) -> Tuple[bool, str, Optional[List[Dict]]]:
self.auto_create_handle_loop()
results = await self.check_weights_communicator(obj)
return FanOutCommunicator.merge_results(results)
success, message = FanOutCommunicator.merge_results(results)
ranks: Optional[List[Dict]] = None
if any(r.payload is not None for r in results):
ranks = [r.payload for r in results]
return success, message, ranks

async def slow_down(
self: TokenizerManager,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3484,7 +3484,7 @@ def save_sharded_model(
ShardedStateLoader.save_model(self.model, path, pattern, max_size)

def check_weights(self, action: str):
    """Run a weights-checker action and return its payload.

    Returns the dict produced by the action (e.g. "checksum"), or None for
    actions that produce no data.  The previous non-returning call was left
    in place next to the returning one, so `handle` executed twice per
    request (duplicating side effects such as snapshotting); keep exactly
    one call and propagate its result.
    """
    return self._weight_checker.handle(action=action)

def update_weights_from_ipc(self, recv_req):
"""Update weights from IPC for checkpoint-engine integration."""
Expand Down
80 changes: 75 additions & 5 deletions python/sglang/srt/utils/weight_checker.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,40 @@
import logging
from typing import Dict, Iterable, Tuple
import time
from typing import Dict, Iterable, Optional, Tuple

import torch
import torch.distributed as dist
from pydantic import BaseModel, ConfigDict

from sglang.srt.layers.quantization.fp8_utils import (
block_quant_dequant,
inverse_transform_scale_ue8m0,
)
from sglang.srt.managers.mm_utils import tensor_hash

logger = logging.getLogger(__name__)


class _StrictBaseModel(BaseModel):
    """Pydantic base that rejects unknown fields (extra="forbid"), so schema
    drift between ranks fails loudly instead of silently dropping data."""

    model_config = ConfigDict(extra="forbid")


class ParallelismInfo(_StrictBaseModel):
    """Position of one worker in the parallel topology."""

    # Tensor-parallel rank and group size.
    tp_rank: int
    tp_size: int
    # Data-parallel rank and group size (rank reported as 0 when unset).
    dp_rank: int
    dp_size: int
    # Pipeline-parallel rank and group size.
    pp_rank: int
    pp_size: int
    # Global rank / world size from torch.distributed (0 / 1 when the
    # process group is not initialized).
    rank: int
    size: int


class ChecksumInfo(_StrictBaseModel):
    """Per-rank checksum report returned by the "checksum" action."""

    # Mapping of tensor name -> 16-char hex digest (see _hash_tensor).
    checksums: Dict[str, str]
    # Topology coordinates of the rank that produced these checksums.
    parallelism_info: ParallelismInfo


_NON_PERSISTENT_BUFFER_PATTERNS = (
"cos_sin_cache",
"inv_freq",
Expand All @@ -28,14 +52,16 @@ def __init__(self, model_runner):
self._model_runner = model_runner
self._snapshot_tensors = None

def handle(self, action: str) -> Optional[Dict]:
    """Dispatch a weight-checker action.

    Returns whatever the chosen handler returns — of the visible actions,
    only "checksum" produces a dict payload; the result is propagated so the
    scheduler can ship it back to the client.  Raises for unknown actions.

    The diff left the old discard-style calls interleaved with the new
    returning calls (two `def` lines, each branch calling its handler
    twice); this keeps exactly one returning call per branch.
    """
    logger.info(f"[WeightChecker] handle action={action}")
    if action == "snapshot":
        return self._snapshot()
    elif action == "reset_tensors":
        return self._reset_tensors()
    elif action == "compare":
        return self._compare()
    elif action == "checksum":
        return self._compute_checksum()
    else:
        raise Exception(f"Unsupported {action=}")

Expand All @@ -62,12 +88,56 @@ def _compare(self):
actual_tensors=_postprocess_tensors(dict(self._model_state())),
)

def _compute_checksum(self) -> Dict:
    """Hash every comparable tensor and return a ChecksumInfo as a plain dict."""
    torch.cuda.synchronize()
    start = time.perf_counter()

    # Reuse the snapshot/compare postprocess pipeline so fp8 weights are
    # dequantized to bf16 before hashing — two (qweight, scale) pairs that
    # produce the same bf16 must produce the same checksum.
    checksums: Dict[str, str] = {}
    for name, should_compare, tensor in _postprocess_tensors(
        dict(self._model_state())
    ):
        if should_compare:
            checksums[name] = _hash_tensor(tensor.data)

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    logger.info(
        f"[WeightChecker] checksum computed for {len(checksums)} tensors in {elapsed:.3f}s"
    )

    # Serialize through the strict pydantic model so the payload shape is
    # validated before it leaves this rank.
    return ChecksumInfo(
        checksums=checksums,
        parallelism_info=self._parallelism_info(),
    ).model_dump()

def _parallelism_info(self) -> ParallelismInfo:
    """Describe this worker's TP/DP/PP coordinates plus its global rank."""
    runner = self._model_runner
    dist_ready = dist.is_initialized()
    return ParallelismInfo(
        tp_rank=runner.tp_rank,
        tp_size=runner.tp_size,
        # dp_rank may be None on non-data-parallel deployments; report 0 then.
        dp_rank=0 if runner.dp_rank is None else runner.dp_rank,
        dp_size=runner.dp_size,
        pp_rank=runner.pp_rank,
        pp_size=runner.pp_size,
        # Fall back to a single-process view when torch.distributed is down.
        rank=dist.get_rank() if dist_ready else 0,
        size=dist.get_world_size() if dist_ready else 1,
    )

def _model_state(self):
# TODO: support EAGLE etc (e.g. yield from both main model and draft model)
yield from self._model_runner.model.named_parameters()
yield from self._model_runner.model.named_buffers()


def _hash_tensor(t: torch.Tensor) -> str:
    """Render tensor_hash(t) as a zero-padded 16-char lowercase hex string."""
    return format(tensor_hash(t), "016x")


def _check_tensors(
expect_tensors: Iterable[Tuple[str, bool, torch.Tensor]],
actual_tensors: Iterable[Tuple[str, bool, torch.Tensor]],
Expand Down
68 changes: 68 additions & 0 deletions test/registered/rl/test_weight_checker_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,74 @@ def test_d_update_with_same_tensor_keeps_compare_passing(self):
self.assertEqual(resp.status_code, 200)
self.assertTrue(resp.json()["success"])

def test_e_checksum_returns_ranks_with_hashes(self):
    """checksum action must yield a ranks list with hex hashes per rank."""
    response = self._post("checksum")
    self.assertEqual(response.status_code, 200)

    payload = response.json()
    self.assertTrue(payload["success"])
    self.assertIn("ranks", payload)

    rank_entries = payload["ranks"]
    self.assertIsInstance(rank_entries, list)
    self.assertGreaterEqual(len(rank_entries), 1)

    head = rank_entries[0]
    self.assertIn("checksums", head)
    self.assertIn("parallelism_info", head)

    # Every topology coordinate must be present in the per-rank info.
    info = head["parallelism_info"]
    expected_keys = (
        "tp_rank",
        "tp_size",
        "dp_rank",
        "dp_size",
        "pp_rank",
        "pp_size",
        "rank",
        "size",
    )
    for key in expected_keys:
        self.assertIn(key, info)

    # Hashes are 16-char hex strings (int() raises if not valid hex).
    checksums = head["checksums"]
    self.assertGreater(len(checksums), 0)
    for name, h in checksums.items():
        self.assertIsInstance(h, str)
        self.assertEqual(len(h), 16, f"unexpected hash length for {name!r}: {h!r}")
        int(h, 16)

def test_e_checksum_is_stable_across_calls(self):
    """Two consecutive checksum calls with no weight update must match."""
    runs = [self._post("checksum").json()["ranks"] for _ in range(2)]
    self.assertEqual(runs[0], runs[1])

def test_e_checksum_changes_after_weight_update(self):
    """Updating a tensor must change its corresponding hash."""
    # up_proj is fused into gate_up_proj on load, so the update targets the
    # unfused name while the checksum map holds the fused one.
    param_name = "model.layers.7.mlp.up_proj.weight"
    fused_name = "model.layers.7.mlp.gate_up_proj.weight"

    def rank0_checksums():
        # Checksum map as reported by rank 0.
        return self._post("checksum").json()["ranks"][0]["checksums"]

    before_hash = rank0_checksums().get(fused_name)
    self.assertIsNotNone(before_hash, f"missing {fused_name!r} in checksum keys")

    update_resp = self._update_weights(
        [(param_name, torch.full(_UP_PROJ_SHAPE, 0.5, device="cuda"))]
    )
    self.assertTrue(update_resp.json()["success"])

    self.assertNotEqual(rank0_checksums()[fused_name], before_hash)

def test_e_checksum_skips_non_persistent_buffers(self):
    """No checksum entry should contain a non-persistent-buffer substring."""
    banned_substrings = ("cos_sin_cache", "inv_freq", "freqs_cis", "_weight_fp32")
    for rank_payload in self._post("checksum").json()["ranks"]:
        for tensor_name in rank_payload["checksums"]:
            for pattern in banned_substrings:
                self.assertNotIn(pattern, tensor_name)

def test_z_snapshot_reset_compare_detects_diff(self):
"""Destructive: leaves weights randomized. Named test_z_* so it runs last."""
self.assertEqual(self._post("snapshot").status_code, 200)
Expand Down
Loading
Loading