diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index b5324d8b575e..8c4a6ddf18d9 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -47,6 +47,7 @@
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
+    GetWeightsChecksumReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterFromTensorsReqInput,
     LoadLoRAAdapterReqInput,
@@ -642,6 +643,16 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
+    def get_weights_checksum(self):
+        """Get model weights checksum from backend.
+
+        Note: with TP > 1, this is a shard-local checksum, not a gathered full-model checksum.
+        """
+        obj = GetWeightsChecksumReqInput()
+        return self.loop.run_until_complete(
+            self.tokenizer_manager.get_weights_checksum(obj, None)
+        )
+
     def load_lora_adapter_from_tensors(
         self, lora_name: str, tensors: List[Tuple[str, torch.Tensor]], config_dict: Dict
     ):
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index c80b54be5ccc..aa515d92c3be 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1475,6 +1475,16 @@ class GetWeightsByNameReqOutput(BaseReq):
     parameter: list
 
 
+@dataclass
+class GetWeightsChecksumReqInput(BaseReq):
+    pass
+
+
+@dataclass
+class GetWeightsChecksumReqOutput(BaseReq):
+    checksum: str
+
+
 @dataclass
 class ReleaseMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 2edec253f824..b3f01e502495 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -98,6 +98,7 @@
     GetLoadReqInput,
     GetLoadsReqInput,
     GetWeightsByNameReqInput,
+    GetWeightsChecksumReqInput,
     HealthCheckOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsSendGroupForRemoteInstanceReqOutput,
@@ -1053,6 +1054,7 @@ def init_request_dispatcher(self):
                 (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                 (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
+                (GetWeightsChecksumReqInput, self.get_weights_checksum),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (CheckWeightsReqInput, self.check_weights),
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index 293a843508b0..7ed427cbb9fd 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -19,6 +19,8 @@
     DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
+    GetWeightsChecksumReqInput,
+    GetWeightsChecksumReqOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
     ReleaseMemoryOccupationReqInput,
@@ -118,6 +120,10 @@ def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
 
+    def get_weights_checksum(self: Scheduler, recv_req: GetWeightsChecksumReqInput):
+        checksum = self.tp_worker.get_weights_checksum(recv_req)
+        return GetWeightsChecksumReqOutput(checksum)
+
     def release_memory_occupation(
         self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
     ):
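For context, a minimal usage sketch of the entrypoint wired up above. This is illustrative only, not part of the patch, and the model path is a placeholder:

    import sglang as sgl

    engine = sgl.Engine(model_path="path/to/model")  # placeholder path
    checksum = engine.get_weights_checksum()  # 64-char hex SHA-256 digest
    # With tp_size > 1 each TP shard hashes only its local parameters, so the
    # digest is shard-local rather than a gathered full-model checksum.
    engine.shutdown()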
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index f2f9791e61a0..f6e886fd3da3 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -47,6 +47,8 @@
     GetLoadsReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
+    GetWeightsChecksumReqInput,
+    GetWeightsChecksumReqOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -188,6 +190,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs):
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.get_weights_checksum_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.release_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -271,6 +276,10 @@ def _get_communicator_dispatcher(self: TokenizerManager):
                 GetWeightsByNameReqOutput,
                 self.get_weights_by_name_communicator.handle_recv,
             ),
+            (
+                GetWeightsChecksumReqOutput,
+                self.get_weights_checksum_communicator.handle_recv,
+            ),
             (
                 ReleaseMemoryOccupationReqOutput,
                 self.release_memory_occupation_communicator.handle_recv,
@@ -812,6 +821,19 @@ async def get_weights_by_name(
         else:
             return all_parameters
 
+    async def get_weights_checksum(
+        self: TokenizerManager,
+        obj: GetWeightsChecksumReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        results = await self.get_weights_checksum_communicator(obj)
+        all_checksums = [r.checksum for r in results]
+        if self.server_args.dp_size == 1:
+            return all_checksums[0]
+        else:
+            return all_checksums
+
     async def release_memory_occupation(
         self: TokenizerManager,
         obj: ReleaseMemoryOccupationReqInput,
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 86b009df4e64..1157ee7958ec 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -25,6 +25,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
+    GetWeightsChecksumReqInput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterFromTensorsReqInput,
@@ -174,6 +175,9 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         )
         return parameter
 
+    def get_weights_checksum(self, recv_req: GetWeightsChecksumReqInput):
+        return self.model_runner.get_weights_checksum()
+
     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
         result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result
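One note on the DP fan-out above: the communicator gathers one GetWeightsChecksumReqOutput per DP rank, so the awaited result changes shape with dp_size. A sketch of the contract, with the dp_size values as assumptions:

    obj = GetWeightsChecksumReqInput()
    result = await tokenizer_manager.get_weights_checksum(obj, None)
    # dp_size == 1 -> a single hex digest (str)
    # dp_size >  1 -> one hex digest per DP replica (list[str]); identical
    #                 replicas should yield identical digests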
"""Compute SHA-256 checksum of parameters local to this ModelRunner (current TP rank).""" + return compute_weights_checksum(self.model.named_parameters()) + def init_lora_manager(self): self.lora_manager = LoRAManager( base_model=self.model, diff --git a/python/sglang/srt/utils/weight_checksum.py b/python/sglang/srt/utils/weight_checksum.py new file mode 100644 index 000000000000..b4c42d191871 --- /dev/null +++ b/python/sglang/srt/utils/weight_checksum.py @@ -0,0 +1,17 @@ +import hashlib + +import torch +from torch.distributed.tensor import DTensor + + +def compute_weights_checksum(named_params): + """Compute a single SHA-256 hash over all weights, sorted by name for determinism.""" + hasher = hashlib.sha256() + for name, tensor in sorted(named_params, key=lambda x: x[0]): + hasher.update(name.encode()) + t = tensor.detach() + # DTensor doesn't support .numpy(); extract the local tensor. + if isinstance(t, DTensor): + t = t._local_tensor + hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data) + return hasher.hexdigest() diff --git a/test/registered/rl/test_update_weights_from_tensor.py b/test/registered/rl/test_update_weights_from_tensor.py index 9000207629b5..3d6322802230 100644 --- a/test/registered/rl/test_update_weights_from_tensor.py +++ b/test/registered/rl/test_update_weights_from_tensor.py @@ -3,6 +3,7 @@ register_cuda_ci(est_time=195, suite="stage-b-test-small-1-gpu") register_amd_ci(est_time=195, suite="stage-b-test-small-1-gpu-amd") +import functools import gc import json import random @@ -15,6 +16,7 @@ import sglang as sgl from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree +from sglang.srt.utils.weight_checksum import compute_weights_checksum from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -24,75 +26,119 @@ popen_launch_server, ) - -def test_update_weights_from_tensor(tp_size): - assert torch.cuda.device_count() >= tp_size, f"At least {tp_size} GPUs are required" - torch.cuda.empty_cache() - - engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, tp_size=tp_size) - - param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)] - - _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110]) - - memory_before = torch.cuda.memory_allocated() - new_tensor = torch.full((16384, 2048), 1.5, device="cuda") - - time_start = time.perf_counter() - engine.update_weights_from_tensor([(x, new_tensor) for x in param_names]) - print(f"Time delta: {time.perf_counter() - time_start:.03f}") - - for param_name in param_names[:3]: - _check_param(engine, param_name, [1.5] * 5) - - engine.shutdown() - - del new_tensor - gc.collect() - torch.cuda.ipc_collect() - torch.cuda.empty_cache() - memory_after = torch.cuda.memory_allocated() - assert ( - memory_after <= memory_before + 1024 - ), f"Memory leak detected: {memory_after - memory_before} bytes" +# Llama stacked params: HF splits -> SGLang merged (concat along dim=0) +_STACKED_PARAMS = [ + (".qkv_proj", [".q_proj", ".k_proj", ".v_proj"]), + (".gate_up_proj", [".gate_proj", ".up_proj"]), +] + +_PERTURB_PARAM_NAME = "model.layers.0.mlp.down_proj.weight" +_PERTURB_NUMEL = 128 + + +def _merge_hf_to_sglang(hf_named_params): + """Merge HF-format params to SGLang internal format (TP=1, concat along dim=0).""" + merged = {} + pending = {} + + for name, tensor in hf_named_params: + matched = False + for merged_suffix, shard_suffixes in _STACKED_PARAMS: + for 
diff --git a/test/registered/rl/test_update_weights_from_tensor.py b/test/registered/rl/test_update_weights_from_tensor.py
index 9000207629b5..3d6322802230 100644
--- a/test/registered/rl/test_update_weights_from_tensor.py
+++ b/test/registered/rl/test_update_weights_from_tensor.py
@@ -3,6 +3,7 @@
 register_cuda_ci(est_time=195, suite="stage-b-test-small-1-gpu")
 register_amd_ci(est_time=195, suite="stage-b-test-small-1-gpu-amd")
 
+import functools
 import gc
 import json
 import random
@@ -15,6 +16,7 @@
 import sglang as sgl
 from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree
+from sglang.srt.utils.weight_checksum import compute_weights_checksum
 from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket
 from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
@@ -24,75 +26,119 @@
     popen_launch_server,
 )
 
-
-def test_update_weights_from_tensor(tp_size):
-    assert torch.cuda.device_count() >= tp_size, f"At least {tp_size} GPUs are required"
-    torch.cuda.empty_cache()
-
-    engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, tp_size=tp_size)
-
-    param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]
-
-    _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
-
-    memory_before = torch.cuda.memory_allocated()
-    new_tensor = torch.full((16384, 2048), 1.5, device="cuda")
-
-    time_start = time.perf_counter()
-    engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
-    print(f"Time delta: {time.perf_counter() - time_start:.03f}")
-
-    for param_name in param_names[:3]:
-        _check_param(engine, param_name, [1.5] * 5)
-
-    engine.shutdown()
-
-    del new_tensor
-    gc.collect()
-    torch.cuda.ipc_collect()
-    torch.cuda.empty_cache()
-    memory_after = torch.cuda.memory_allocated()
-    assert (
-        memory_after <= memory_before + 1024
-    ), f"Memory leak detected: {memory_after - memory_before} bytes"
+# Llama stacked params: HF splits -> SGLang merged (concat along dim=0)
+_STACKED_PARAMS = [
+    (".qkv_proj", [".q_proj", ".k_proj", ".v_proj"]),
+    (".gate_up_proj", [".gate_proj", ".up_proj"]),
+]
+
+_PERTURB_PARAM_NAME = "model.layers.0.mlp.down_proj.weight"
+_PERTURB_NUMEL = 128
+
+
+def _merge_hf_to_sglang(hf_named_params):
+    """Merge HF-format params to SGLang internal format (TP=1, concat along dim=0)."""
+    merged = {}
+    pending = {}
+
+    for name, tensor in hf_named_params:
+        matched = False
+        for merged_suffix, shard_suffixes in _STACKED_PARAMS:
+            for shard_suffix in shard_suffixes:
+                if shard_suffix in name:
+                    merged_name = name.replace(shard_suffix, merged_suffix)
+                    pending.setdefault(merged_name, {})[shard_suffix] = tensor
+                    matched = True
+                    break
+            if matched:
+                break
+        if not matched:
+            merged[name] = tensor
+
+    for name, parts in pending.items():
+        for merged_suffix, shard_suffixes in _STACKED_PARAMS:
+            if merged_suffix in name:
+                merged[name] = torch.cat([parts[s] for s in shard_suffixes], dim=0)
+                break
+
+    return merged
+
+
+@functools.lru_cache(maxsize=1)
+def _load_hf_params():
+    """Load and cache HF model params for cross-verification."""
+    from transformers import AutoModelForCausalLM
+
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        DEFAULT_SMALL_MODEL_NAME_FOR_TEST, torch_dtype=torch.bfloat16
+    )
+    params = [(n, p.detach().clone()) for n, p in hf_model.named_parameters()]
+    del hf_model
+    return params
+
+
+@functools.lru_cache(maxsize=1)
+def _load_perturbed_hf_params():
+    """Return HF params with a deterministic perturbation on one tensor."""
+    perturbed = []
+    found = False
+    for name, tensor in _load_hf_params():
+        cloned = tensor.clone()
+        if name == _PERTURB_PARAM_NAME:
+            numel = min(_PERTURB_NUMEL, cloned.numel())
+            delta = torch.linspace(0.01, 0.02, steps=numel, dtype=torch.float32).to(
+                cloned.dtype
+            )
+            cloned.view(-1)[:numel].add_(delta)
+            found = True
+        perturbed.append((name, cloned))
+
+    assert found, f"Cannot find parameter to perturb: {_PERTURB_PARAM_NAME}"
+    return perturbed
+
+
+@functools.lru_cache(maxsize=1)
+def _expected_checksum_after_perturbation():
+    """Compute expected checksum from perturbed HF params merged to SGLang format."""
+    merged = _merge_hf_to_sglang(_load_perturbed_hf_params())
+    return compute_weights_checksum(merged.items())
 
 
 class TestUpdateWeightsFromTensor(CustomTestCase):
     def test_update_weights_from_tensor(self):
-        tp_sizes = [1, 2]
-        for tp_size in tp_sizes:
-            if torch.cuda.device_count() < tp_size:
-                continue
+        torch.cuda.empty_cache()
+        memory_before = torch.cuda.memory_allocated()
 
-            with self.subTest(tp_size=tp_size):
-                test_update_weights_from_tensor(tp_size)
-
-    def test_update_weights_from_tensor_load_format_direct(self):
         engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
 
-        write_param_names = [
-            f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
-        ]
-        read_param_names = [
-            f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
-        ]
+        checksum_before = engine.get_weights_checksum()
+        hf_params = _load_perturbed_hf_params()
+        engine.update_weights_from_tensor(list(hf_params))
 
-        _check_param(
-            engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
-        )
+        checksum_after = engine.get_weights_checksum()
+        assert checksum_after == _expected_checksum_after_perturbation()
+        assert checksum_after != checksum_before
+        engine.shutdown()
 
-        new_tensor = torch.full((3072, 2048), 1.5)
-        engine.update_weights_from_tensor(
-            [
-                (write_param_name, new_tensor.clone())
-                for write_param_name in write_param_names
-            ],
-            load_format="direct",
-        )
+        gc.collect()
+        torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        memory_after = torch.cuda.memory_allocated()
+        assert (
+            memory_after <= memory_before + 1024
+        ), f"Memory leak detected: {memory_after - memory_before} bytes"
 
-        for read_param_name in read_param_names[:3]:
-            _check_param(engine, read_param_name, [1.5] * 5)
+    def test_update_weights_from_tensor_load_format_direct(self):
+        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
 
+        checksum_before = engine.get_weights_checksum()
+        # Direct format bypasses merge; send already-merged params
+        merged = _merge_hf_to_sglang(_load_perturbed_hf_params())
+        engine.update_weights_from_tensor(list(merged.items()), load_format="direct")
+
+        checksum_after = engine.get_weights_checksum()
+        assert checksum_after == _expected_checksum_after_perturbation()
+        assert checksum_after != checksum_before
         engine.shutdown()
 
     def test_update_weights_from_tensor_load_format_custom(self):
@@ -104,82 +150,39 @@ def test_update_weights_from_tensor_load_format_custom(self):
             custom_weight_loader=[custom_loader_name],
         )
 
-        write_param_names = [
-            f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
-        ]
-        read_param_names = [
-            f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
-        ]
-
-        _check_param(
-            engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
-        )
-
-        new_tensor = torch.full((3072, 2048), 1.5)
+        checksum_before = engine.get_weights_checksum()
+        merged = _merge_hf_to_sglang(_load_perturbed_hf_params())
         engine.update_weights_from_tensor(
-            [
-                (write_param_name, new_tensor.clone())
-                for write_param_name in write_param_names
-            ],
-            load_format=custom_loader_name,
+            list(merged.items()), load_format=custom_loader_name
         )
 
-        for read_param_name in read_param_names[:3]:
-            _check_param(engine, read_param_name, [1.5] * 5)
-
+        checksum_after = engine.get_weights_checksum()
+        assert checksum_after == _expected_checksum_after_perturbation()
+        assert checksum_after != checksum_before
         engine.shutdown()
 
     def test_update_weights_from_tensor_load_format_flattened_bucket(self):
-        """Test updating weights using flattened_bucket format"""
         engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
 
-        # Create a small set of parameters for testing
-        param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 10)]
-
-        # Check original values
-        _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
-
-        # Create new tensors with different values
-        new_tensors = []
-        for _, name in enumerate(param_names):
-            # Create tensors with different values for each parameter
-            value = 2.0  # Different value for each parameter
-            new_tensor = torch.full((16384, 2048), value, device="cuda")
-            new_tensors.append((name, new_tensor))
-
-        # Create a flattened bucket
-        flattened_bucket = FlattenedTensorBucket(named_tensors=new_tensors)
+        checksum_before = engine.get_weights_checksum()
+        # Flattened bucket calls model.load_weights() internally, so use HF-format names
+        hf_params = _load_perturbed_hf_params()
+        named_tensors = [(n, t.cuda()) for n, t in hf_params]
 
-        # Extract the flattened tensor and metadata in the format expected by model_runner
-        flattened_tensor = flattened_bucket.get_flattened_tensor()
-        metadata = flattened_bucket.get_metadata()
+        bucket = FlattenedTensorBucket(named_tensors=named_tensors)
+        bucket_dict = {
+            "flattened_tensor": bucket.get_flattened_tensor(),
+            "metadata": bucket.get_metadata(),
+        }
+        serialized = MultiprocessingSerializer.serialize(bucket_dict, output_str=True)
 
-        # Create the dict format expected by _update_weights_from_flattened_bucket
-        bucket_dict = {"flattened_tensor": flattened_tensor, "metadata": metadata}
-
-        # Serialize the bucket data
-        from sglang.srt.utils import MultiprocessingSerializer
-
-        serialized_bucket = MultiprocessingSerializer.serialize(
-            bucket_dict, output_str=True
-        )
-
-        # Create a list where each rank contains the same serialized data
-        # This simulates the distributed environment where each rank has the same data
-        serialized_bucket_list = [serialized_bucket]
-
-        # Update weights using flattened_bucket format
-        time_start = time.perf_counter()
         engine.update_weights_from_tensor(
-            named_tensors=serialized_bucket_list, load_format="flattened_bucket"
+            named_tensors=[serialized], load_format="flattened_bucket"
        )
-        update_time = time.perf_counter() - time_start
-        print(f"Flattened bucket update time: {update_time:.03f}")
-
-        # Verify the weights were updated correctly
-        for i, param_name in enumerate(param_names):
-            _check_param(engine, param_name, [2.0] * 5)
+        checksum_after = engine.get_weights_checksum()
+        assert checksum_after == _expected_checksum_after_perturbation()
+        assert checksum_after != checksum_before
 
         engine.shutdown()
 
@@ -289,12 +292,5 @@ def test_update_weights(self):
         ), f"{actual_values=}"
 
 
-def _check_param(engine, param_name, expect_values):
-    actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
-    assert torch.allclose(
-        actual_values, torch.tensor(expect_values), atol=0.002
-    ), f"{actual_values=}"
-
-
 if __name__ == "__main__":
     unittest.main()
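For intuition on the _merge_hf_to_sglang helper used throughout these tests, a toy illustration of the stacked-param mapping; the shapes are made up and far smaller than the real model's, and the snippet assumes it runs inside the test module's scope:

    import torch

    hf = [
        ("model.layers.0.self_attn.q_proj.weight", torch.zeros(8, 4)),
        ("model.layers.0.self_attn.k_proj.weight", torch.zeros(2, 4)),
        ("model.layers.0.self_attn.v_proj.weight", torch.zeros(2, 4)),
    ]
    merged = _merge_hf_to_sglang(hf)
    assert list(merged) == ["model.layers.0.self_attn.qkv_proj.weight"]
    assert merged["model.layers.0.self_attn.qkv_proj.weight"].shape == (12, 4)  # q, k, v stacked along dim=0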