python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py
@@ -1,6 +1,7 @@
"""Request/response data structures for post-training APIs."""

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
@@ -12,6 +13,39 @@ class UpdateWeightFromDiskReqInput:
    target_modules: list[str] | None = None


@dataclass
class UpdateWeightFromTensorReqInput:
    """Request to update model weights from tensor payloads for diffusion models."""

    serialized_named_tensors: list[Union[str, bytes]]
    load_format: Optional[str] = None
    target_modules: list[str] | None = None
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightFromTensorReqOutput:
    """Response for update_weights_from_tensor request."""

    success: bool
    message: str


@dataclass
class UpdateWeightFromTensorCheckerReqInput:
    """Request to verify live transformer weights against expected SHA-256 values."""

    expected_transformer_sha256: list[dict[str, str]]


@dataclass
class UpdateWeightFromTensorCheckerReqOutput:
    """Response for update_weights_from_tensor_checker request."""

    success: bool
    message: str


@dataclass
class GetWeightsChecksumReqInput:
"""Compute SHA-256 checksum of loaded module weights for verification."""
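
For orientation, a client might construct the new request like the minimal sketch below. The serialization is an assumption: the wire format the scheduler expects is not shown in this diff, so the torch.save bytes and the _serialize_named_tensors helper are purely illustrative stand-ins.

import io

import torch

from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
    UpdateWeightFromTensorReqInput,
)


def _serialize_named_tensors(named_tensors):
    # Hypothetical helper: pack (name, tensor) pairs into a single bytes blob.
    buffer = io.BytesIO()
    torch.save(named_tensors, buffer)
    return buffer.getvalue()


payload = _serialize_named_tensors([("proj.weight", torch.zeros(4, 4))])
req = UpdateWeightFromTensorReqInput(
    serialized_named_tensors=[payload],
    target_modules=["transformer"],
    weight_version="v2",
)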
@@ -6,6 +6,8 @@
from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
    GetWeightsChecksumReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromTensorCheckerReqInput,
    UpdateWeightFromTensorReqInput,
)
from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client

@@ -46,6 +48,76 @@ async def update_weights_from_disk(request: Request):
    )


@router.post("/update_weights_from_tensor")
async def update_weights_from_tensor(request: Request):
    """Update model weights from serialized tensor payloads."""
    body = await request.json()
    serialized_named_tensors = body.get("serialized_named_tensors")
    if not serialized_named_tensors:
        return ORJSONResponse(
            {"success": False, "message": "serialized_named_tensors is required"},
            status_code=400,
        )

    req = UpdateWeightFromTensorReqInput(
        serialized_named_tensors=serialized_named_tensors,
        load_format=body.get("load_format"),
        target_modules=body.get("target_modules"),
        weight_version=body.get("weight_version"),
    )

    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse(
            {"success": False, "message": str(e)},
            status_code=500,
        )

    result = response.output
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return ORJSONResponse(
        {"success": success, "message": message},
        status_code=200 if success else 400,
    )
Comment on lines +69 to +83
Contributor (medium):

There is significant code duplication in handling scheduler requests and responses. The logic within this try...except block and the subsequent response parsing is nearly identical in update_weights_from_disk, update_weights_from_tensor, and update_weights_from_tensor_checker. To improve maintainability and reduce redundancy, consider extracting this common logic into a helper function.

For example, you could create a helper like this:

async def _forward_and_respond(req: Any):
    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse(
            {"success": False, "message": str(e)},
            status_code=500,
        )

    result = response.output
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return ORJSONResponse(
        {"success": success, "message": message},
        status_code=200 if success else 400,
    )

Each endpoint could then call this helper after creating the specific request object, simplifying the endpoint logic.
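
For a sense of the wire format, a call to the new endpoint might look like the sketch below. The host and port are placeholders, and the base64 step is an assumption: the endpoint parses JSON, raw bytes are not JSON-serializable, and the expected text encoding is not shown in this diff. payload is the bytes blob from the io_struct sketch earlier.

import base64

import requests

resp = requests.post(
    "http://localhost:30000/update_weights_from_tensor",
    json={
        # bytes must be text-encoded to travel inside a JSON body
        "serialized_named_tensors": [base64.b64encode(payload).decode("ascii")],
        "target_modules": ["transformer"],
        "weight_version": "v2",
    },
)
print(resp.status_code, resp.json())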



@router.post("/update_weights_from_tensor_checker")
async def update_weights_from_tensor_checker(request: Request):
    """Verify live transformer weights against expected SHA-256 values."""
    body = await request.json()
    expected_transformer_sha256 = body.get("expected_transformer_sha256")
    if (
        not isinstance(expected_transformer_sha256, list)
        or not expected_transformer_sha256
    ):
        return ORJSONResponse(
            {
                "success": False,
                "message": "expected_transformer_sha256 must be a non-empty list",
            },
            status_code=400,
        )

    req = UpdateWeightFromTensorCheckerReqInput(
        expected_transformer_sha256=expected_transformer_sha256,
    )

    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse(
            {"success": False, "message": str(e)},
            status_code=500,
        )

    result = response.output
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return ORJSONResponse(
        {"success": success, "message": message},
        status_code=200 if success else 400,
    )
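
The checker endpoint can be exercised the same way. The entry shape below, mapping a parameter name to a hex digest, is inferred from the list[dict[str, str]] annotation and is an assumption; how the server computes each digest is not visible in this diff.

import hashlib

import requests

# Assumed entry shape: {<parameter name>: <sha256 hex digest>}.
digest = hashlib.sha256(b"raw tensor bytes").hexdigest()
resp = requests.post(
    "http://localhost:30000/update_weights_from_tensor_checker",
    json={"expected_transformer_sha256": [{"proj.weight": digest}]},
)
print(resp.json())  # e.g. {"success": True, "message": "..."}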


@router.post("/get_weights_checksum")
async def get_weights_checksum(request: Request):
"""Return SHA-256 checksum of each requested module's weights."""
160 changes: 160 additions & 0 deletions python/sglang/multimodal_gen/runtime/loader/weights_updater.py
@@ -42,6 +42,7 @@

import gc
from pathlib import Path
from typing import Any

import torch
from torch.distributed.tensor import DTensor, distribute_tensor
@@ -57,6 +58,10 @@
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

logger = init_logger(__name__)

@@ -291,3 +296,158 @@ def _rollback(self, updated_modules: list[str]) -> None:
                continue
            weights_iter = _get_weights_iter(str(weights_dir))
            _load_weights_into_module(module, weights_iter)

    def update_weights_from_tensor(
        self,
        named_tensors: Any,
        load_format: str | None = None,
        target_modules: list[str] | None = None,
    ) -> tuple[bool, str]:
        """Update module weights from in-memory tensor payloads."""
        try:
            modules_to_update = self._collect_modules(target_modules)
        except ValueError as e:
            logger.error(str(e))
            return False, str(e)

        if not modules_to_update:
            error_msg = (
                f"No matching modules found for update. "
                f"Requested: {target_modules}. "
                f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}"
            )
            logger.error(error_msg)
            return False, error_msg

        try:
            module_payloads = self._resolve_module_payloads(
                named_tensors=named_tensors,
                modules_to_update=modules_to_update,
            )
        except ValueError as e:
            logger.error(str(e))
            return False, str(e)

        updated_modules: list[str] = []
        for module_name, module in modules_to_update:
            try:
                payload = module_payloads[module_name]
                weights_iter = self._materialize_weights_iter(payload, load_format)
                _load_weights_into_module(module, weights_iter)
                updated_modules.append(module_name)
            except Exception as e:
                rollback_list = updated_modules + [module_name]
                logger.error(
                    f"Tensor weight update failed for module '{module_name}': {e}. "
                    f"Rolling back {len(rollback_list)} module(s) "
                    f"(including partially-loaded '{module_name}'): "
                    f"{rollback_list}.",
                    exc_info=True,
                )
                self._rollback(rollback_list)
                return False, (
                    f"Failed to update module '{module_name}' from tensor: {e}. "
                    f"All modules rolled back to original weights."
                )

        gc.collect()
        torch.cuda.empty_cache()
        names = ", ".join(updated_modules)
        message = f"Updated {len(updated_modules)} modules from tensor ({names})."
        logger.info(message)
        return True, message

    def _resolve_module_payloads(
        self,
        named_tensors: Any,
        modules_to_update: list[tuple[str, torch.nn.Module]],
    ) -> dict[str, Any]:
        """Resolve a generic tensor payload to a per-module payload mapping."""
        module_names = [name for name, _ in modules_to_update]

        # Preferred format for multi-module update:
        # {
        #     "transformer": <module_payload>,
        #     "vae": <module_payload>,
        # }
        if isinstance(named_tensors, dict):
            missing = [name for name in module_names if name not in named_tensors]
            if missing:
                raise ValueError(
                    f"Missing tensor payload for module(s): {missing}. "
                    f"Provided modules: {list(named_tensors.keys())}"
                )
            return {name: named_tensors[name] for name in module_names}

        # Single-module shortcut: allow a direct payload when exactly one target module exists.
        if len(module_names) == 1:
            return {module_names[0]: named_tensors}

        raise ValueError(
            "Ambiguous tensor payload for multi-module update. "
            "Provide a dict mapping module_name -> module payload; "
            f"requested modules: {module_names}."
        )

    def _materialize_weights_iter(self, module_payload: Any, load_format: str | None):
        """Convert one module payload to an iterator of (param_name, tensor)."""
        if load_format == "flattened_bucket":
            if not isinstance(module_payload, dict):
                raise ValueError(
                    "flattened_bucket payload must be a dict with "
                    "'flattened_tensor' and 'metadata'."
                )
            flattened_tensor = module_payload.get("flattened_tensor")
            metadata = module_payload.get("metadata")
            if flattened_tensor is None or metadata is None:
                raise ValueError(
                    "flattened_bucket payload missing 'flattened_tensor' or 'metadata'."
                )
            return self._reconstruct_from_flattened_bucket(flattened_tensor, metadata)

        # Default/direct format: list/tuple of (name, tensor)
        if isinstance(module_payload, (list, tuple)):
            return iter(module_payload)

        raise ValueError(
            f"Unsupported module payload type for load_format={load_format}: "
            f"{type(module_payload).__name__}"
        )

    def _reconstruct_from_flattened_bucket(self, flattened_tensor: Any, metadata: Any):
        """Reconstruct [(name, tensor), ...] from a flattened-bucket payload."""
        if not isinstance(flattened_tensor, torch.Tensor):
            raise ValueError(
                "flattened_bucket 'flattened_tensor' must be a torch.Tensor."
            )
        if not isinstance(metadata, list):
            raise ValueError("flattened_bucket 'metadata' must be a list.")

        converted_metadata: list[FlattenedTensorMetadata] = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=torch.Size(meta.shape),
                dtype=self._normalize_torch_dtype(meta.dtype),
                start_idx=int(meta.start_idx),
                end_idx=int(meta.end_idx),
                numel=int(meta.numel),
            )
            converted_metadata.append(converted_meta)
Comment on lines +426 to +436
Contributor (medium):

For improved readability and conciseness, this loop for building converted_metadata can be refactored into a list comprehension. This is a more idiomatic Python approach for transforming lists.

Suggested change:

-        converted_metadata: list[FlattenedTensorMetadata] = []
-        for meta in metadata:
-            converted_meta = FlattenedTensorMetadata(
-                name=meta.name,
-                shape=torch.Size(meta.shape),
-                dtype=self._normalize_torch_dtype(meta.dtype),
-                start_idx=int(meta.start_idx),
-                end_idx=int(meta.end_idx),
-                numel=int(meta.numel),
-            )
-            converted_metadata.append(converted_meta)
+        converted_metadata = [
+            FlattenedTensorMetadata(
+                name=meta.name,
+                shape=torch.Size(meta.shape),
+                dtype=self._normalize_torch_dtype(meta.dtype),
+                start_idx=int(meta.start_idx),
+                end_idx=int(meta.end_idx),
+                numel=int(meta.numel),
+            )
+            for meta in metadata
+        ]


        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor,
            metadata=converted_metadata,
        )
        return bucket.reconstruct_tensors()

    def _normalize_torch_dtype(self, dtype: Any) -> torch.dtype:
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            # Supports "torch.float16" and "float16".
            name = dtype.split(".")[-1]
            normalized = getattr(torch, name, None)
            if isinstance(normalized, torch.dtype):
                return normalized
        raise ValueError(f"Unsupported dtype in flattened_bucket metadata: {dtype!r}")
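
For context, the sender side of the flattened_bucket format might be produced as sketched below. Only the receiving path appears in this diff, so the named_tensors constructor argument and the get_flattened_tensor()/get_metadata() accessors on FlattenedTensorBucket are assumptions about the sglang.srt API rather than confirmed usage.

import torch

from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

named_tensors = [
    ("proj.weight", torch.randn(8, 8, dtype=torch.float16)),
    ("proj.bias", torch.randn(8, dtype=torch.float16)),
]

# Pack every tensor into one contiguous buffer plus per-tensor metadata.
bucket = FlattenedTensorBucket(named_tensors=named_tensors)
module_payload = {
    "flattened_tensor": bucket.get_flattened_tensor(),
    "metadata": bucket.get_metadata(),
}

# The receiver (this PR) reverses the packing:
restored = FlattenedTensorBucket(
    flattened_tensor=module_payload["flattened_tensor"],
    metadata=module_payload["metadata"],
).reconstruct_tensors()
for name, tensor in restored:
    print(name, tuple(tensor.shape), tensor.dtype)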