Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
GetWeightsChecksumReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterFromTensorsReqInput,
LoadLoRAAdapterReqInput,
Expand Down Expand Up @@ -642,6 +643,16 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100):
self.tokenizer_manager.get_weights_by_name(obj, None)
)

def get_weights_checksum(self):
    """Fetch the checksum of the model weights from the backend.

    Note: when TP > 1 this is a shard-local checksum; shards are not
    gathered into a single full-model checksum.
    """
    req = GetWeightsChecksumReqInput()
    coro = self.tokenizer_manager.get_weights_checksum(req, None)
    return self.loop.run_until_complete(coro)

def load_lora_adapter_from_tensors(
self, lora_name: str, tensors: List[Tuple[str, torch.Tensor]], config_dict: Dict
):
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,16 @@ class GetWeightsByNameReqOutput(BaseReq):
parameter: list


@dataclass
class GetWeightsChecksumReqInput(BaseReq):
    # Empty request payload: the checksum operation takes no arguments.
    pass


@dataclass
class GetWeightsChecksumReqOutput(BaseReq):
    # Hex-encoded SHA-256 digest of the responding rank's local weights.
    checksum: str


@dataclass
class ReleaseMemoryOccupationReqInput(BaseReq):
# Optional tags to identify the memory region, which is primarily used for RL
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
GetLoadReqInput,
GetLoadsReqInput,
GetWeightsByNameReqInput,
GetWeightsChecksumReqInput,
HealthCheckOutput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsSendGroupForRemoteInstanceReqOutput,
Expand Down Expand Up @@ -1053,6 +1054,7 @@ def init_request_dispatcher(self):
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(GetWeightsChecksumReqInput, self.get_weights_checksum),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(CheckWeightsReqInput, self.check_weights),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
DestroyWeightsUpdateGroupReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
GetWeightsChecksumReqInput,
GetWeightsChecksumReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
ReleaseMemoryOccupationReqInput,
Expand Down Expand Up @@ -118,6 +120,10 @@ def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)

def get_weights_checksum(self: Scheduler, recv_req: GetWeightsChecksumReqInput):
    """Compute this rank's weights checksum and wrap it in a response object."""
    return GetWeightsChecksumReqOutput(self.tp_worker.get_weights_checksum(recv_req))

def release_memory_occupation(
self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
):
Expand Down
22 changes: 22 additions & 0 deletions python/sglang/srt/managers/tokenizer_communicator_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
GetLoadsReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
GetWeightsChecksumReqInput,
GetWeightsChecksumReqOutput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsSendGroupForRemoteInstanceReqOutput,
InitWeightsUpdateGroupReqInput,
Expand Down Expand Up @@ -188,6 +190,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs):
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_checksum_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.release_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
Expand Down Expand Up @@ -271,6 +276,10 @@ def _get_communicator_dispatcher(self: TokenizerManager):
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
),
(
GetWeightsChecksumReqOutput,
self.get_weights_checksum_communicator.handle_recv,
),
(
ReleaseMemoryOccupationReqOutput,
self.release_memory_occupation_communicator.handle_recv,
Expand Down Expand Up @@ -812,6 +821,19 @@ async def get_weights_by_name(
else:
return all_parameters

async def get_weights_checksum(
    self: TokenizerManager,
    obj: GetWeightsChecksumReqInput,
    request: Optional[fastapi.Request] = None,
):
    """Collect weight checksums from the scheduler(s).

    Returns a single checksum string when dp_size == 1, otherwise a list
    with one checksum per DP rank.
    """
    self.auto_create_handle_loop()
    outputs = await self.get_weights_checksum_communicator(obj)
    checksums = [out.checksum for out in outputs]
    return checksums if self.server_args.dp_size != 1 else checksums[0]

async def release_memory_occupation(
self: TokenizerManager,
obj: ReleaseMemoryOccupationReqInput,
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput,
GetWeightsByNameReqInput,
GetWeightsChecksumReqInput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterFromTensorsReqInput,
Expand Down Expand Up @@ -174,6 +175,9 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
)
return parameter

def get_weights_checksum(self, recv_req: GetWeightsChecksumReqInput):
    """Return this worker's weights checksum (recv_req carries no payload)."""
    return self.model_runner.get_weights_checksum()

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
)
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils.weight_checker import WeightChecker
from sglang.srt.utils.weight_checksum import compute_weights_checksum
from sglang.srt.weight_sync.tensor_bucket import (
FlattenedTensorBucket,
FlattenedTensorMetadata,
Expand Down Expand Up @@ -1499,6 +1500,10 @@ def get_weights_by_name(
logger.error(f"Error when getting parameter {name}: {e}")
return None

def get_weights_checksum(self):
    """Return the SHA-256 checksum of the parameters owned by this ModelRunner.

    Covers only the parameters local to the current TP rank.
    """
    named = self.model.named_parameters()
    return compute_weights_checksum(named)

def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,
Expand Down
17 changes: 17 additions & 0 deletions python/sglang/srt/utils/weight_checksum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import hashlib

import torch
from torch.distributed.tensor import DTensor


def compute_weights_checksum(named_params):
"""Compute a single SHA-256 hash over all weights, sorted by name for determinism."""
hasher = hashlib.sha256()
for name, tensor in sorted(named_params, key=lambda x: x[0]):
hasher.update(name.encode())
t = tensor.detach()
# DTensor doesn't support .numpy(); extract the local tensor.
if isinstance(t, DTensor):
t = t._local_tensor
Comment on lines +14 to +15
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using _local_tensor accesses a private attribute of DTensor, which could be unstable across different PyTorch versions. It's better to use the public API to_local() to get the local tensor shard. This will improve code maintainability and stability.

Suggested change
if isinstance(t, DTensor):
t = t._local_tensor
if isinstance(t, DTensor):
t = t.to_local()

hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data)
return hasher.hexdigest()
Loading
Loading