
[diffusion] Add update_weights_from_tensor checker#21106

Draft
MikukuOvO wants to merge 8 commits into sgl-project:main from MikukuOvO:fenglin/update-weight-from-tensor-checker

Conversation

@MikukuOvO
Contributor

Motivation

Depends on #20464.

This PR adds a diffusion-side checker for the update_weights_from_tensor workflow.
The goal is to verify that DiT (transformer) tensors are updated correctly after client-to-server tensor transfer, and to catch cases where tensor contents or mapping drift during the update path.

Modifications

  • Added a new post-training checker request/response:
    • UpdateWeightFromTensorCheckerReqInput
    • UpdateWeightFromTensorCheckerReqOutput
  • Added a new endpoint:
    • POST /update_weights_from_tensor_checker
  • Added scheduler and worker support for the new checker request.
  • Added UpdateWeightFromTensorChecker in diffusion runtime utils.
  • The checker:
    • only verifies the live transformer module,
    • compares per-tensor SHA-256 manifests,
    • supports both regular tensors and DTensor local shards,
    • aggregates checker results across TP ranks.
  • Added:
    • a lightweight unit test for checker logic,
    • a black-box e2e test covering:
      • successful update_weights_from_tensor -> checker,
      • failure when the expected SHA manifest does not match the updated weights.
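As a rough sketch of the comparison the checker performs, the manifest diff can be expressed as below. The helper names and result shapes here are illustrative only, not the PR's actual code:

```python
import hashlib

def tensor_sha256(raw: bytes) -> str:
    # Hash a tensor's raw bytes, e.g. obtained via
    # tensor.detach().cpu().contiguous().numpy().tobytes()
    return hashlib.sha256(raw).hexdigest()

def diff_manifests(
    expected: dict[str, str], actual: dict[str, str]
) -> tuple[list[str], list[str]]:
    # Tensor names the live module is missing entirely.
    missing = sorted(n for n in expected if n not in actual)
    # Tensors present but whose contents drifted during the update.
    mismatched = sorted(
        n for n in expected if n in actual and actual[n] != expected[n]
    )
    return missing, mismatched
```

A check passes only when both lists come back empty; anything else indicates content or mapping drift along the update path.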

Accuracy Tests

This PR does not change model forward behavior or inference math.

Tests run:

  • python test/test_update_weight_from_tensor_checker.py
  • python test/test_update_weights_from_tensor_checker_e2e.py
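For reference, a black-box client call to the new endpoint could look like the sketch below. The payload field name is an assumption; consult UpdateWeightFromTensorCheckerReqInput for the actual request schema:

```python
import json
from urllib import request

def check_updated_weights(base_url: str, expected_sha256: dict[str, str]) -> dict:
    # NOTE: the field name "expected_transformer_sha256" is assumed,
    # not taken from the PR; see the request dataclass for the real schema.
    payload = json.dumps({"expected_transformer_sha256": expected_sha256}).encode()
    req = request.Request(
        f"{base_url}/update_weights_from_tensor_checker",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # A mismatch returns HTTP 400, which urlopen surfaces as HTTPError.
    with request.urlopen(req) as resp:
        return json.loads(resp.read())
```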

Benchmarking and Profiling

This PR adds a debug/verification path only.
No speed benchmarking was performed.

Checklist

@github-actions github-actions bot added the documentation (Improvements or additions to documentation) and diffusion (SGLang Diffusion) labels Mar 21, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a critical verification mechanism for diffusion models, specifically targeting the update_weights_from_tensor workflow. The primary goal is to ensure the integrity and correctness of transformer module weights after they have been transferred from a client to the server. By implementing a new checker that compares SHA-256 manifests of live tensors against expected values, the system can now proactively detect any data corruption or mapping discrepancies that might occur during the update process, thereby enhancing the reliability of dynamic weight updates.

Highlights

  • New Weight Verification Endpoint: Introduced a new API endpoint (/update_weights_from_tensor_checker) and associated data structures to verify diffusion model weights after client-to-server tensor transfers.
  • Robust Tensor Checker Implementation: Developed UpdateWeightFromTensorChecker to compare live transformer module weights against expected SHA-256 manifests, supporting both regular and distributed (DTensor) local shards, and aggregating results across Tensor Parallel (TP) ranks.
  • Integrated Workflow Support: Added scheduler and worker support for the new checker request, ensuring proper handling and execution within the diffusion runtime.
  • Comprehensive Testing: Included dedicated unit tests for the checker logic and black-box end-to-end tests to validate the entire workflow, including successful updates and failure detection for mismatched weights.
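The TP-rank aggregation mentioned above reduces to an all-ranks-must-pass fold. In-process this would typically be a torch.distributed collective (e.g. an all_reduce over a success flag), but the logic can be sketched single-process as follows, with the per-rank result shape assumed:

```python
def aggregate_tp_results(rank_results: list[dict]) -> dict:
    # The checker succeeds only if every TP rank verified its local shards.
    # In the real runtime this fold would run over results gathered from
    # all ranks; the dict shape here is an assumption for illustration.
    success = all(r["success"] for r in rank_results)
    failures = [
        f"rank {i}: {r['message']}"
        for i, r in enumerate(rank_results)
        if not r["success"]
    ]
    return {
        "success": success,
        "message": "all ranks verified" if success else "; ".join(failures),
    }
```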



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable feature for verifying tensor updates in diffusion models by adding a weight checker. The implementation is comprehensive, including new API endpoints, scheduler and worker integration, and the core checker logic with support for tensor parallelism. The addition of both unit and end-to-end tests is commendable and ensures the feature's robustness. My review focuses on improving maintainability and readability by addressing code duplication and suggesting minor optimizations. Overall, this is a solid contribution.

Comment on lines +69 to +83
try:
    response = await async_scheduler_client.forward(req)
except Exception as e:
    return ORJSONResponse(
        {"success": False, "message": str(e)},
        status_code=500,
    )

result = response.output
success = result.get("success", False)
message = result.get("message", "Unknown status")
return ORJSONResponse(
    {"success": success, "message": message},
    status_code=200 if success else 400,
)

Severity: medium

There is significant code duplication in handling scheduler requests and responses. The logic within this try...except block and the subsequent response parsing is nearly identical in update_weights_from_disk, update_weights_from_tensor, and update_weights_from_tensor_checker. To improve maintainability and reduce redundancy, consider extracting this common logic into a helper function.

For example, you could create a helper like this:

async def _forward_and_respond(req: Any):
    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse(
            {"success": False, "message": str(e)},
            status_code=500,
        )

    result = response.output
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return ORJSONResponse(
        {"success": success, "message": message},
        status_code=200 if success else 400,
    )

Each endpoint could then call this helper after creating the specific request object, simplifying the endpoint logic.

Comment on lines +426 to +436
converted_metadata: list[FlattenedTensorMetadata] = []
for meta in metadata:
    converted_meta = FlattenedTensorMetadata(
        name=meta.name,
        shape=torch.Size(meta.shape),
        dtype=self._normalize_torch_dtype(meta.dtype),
        start_idx=int(meta.start_idx),
        end_idx=int(meta.end_idx),
        numel=int(meta.numel),
    )
    converted_metadata.append(converted_meta)

Severity: medium

For improved readability and conciseness, this loop for building converted_metadata can be refactored into a list comprehension. This is a more idiomatic Python approach for transforming lists.

Suggested change (replace the loop above with a list comprehension):

converted_metadata = [
    FlattenedTensorMetadata(
        name=meta.name,
        shape=torch.Size(meta.shape),
        dtype=self._normalize_torch_dtype(meta.dtype),
        start_idx=int(meta.start_idx),
        end_idx=int(meta.end_idx),
        numel=int(meta.numel),
    )
    for meta in metadata
]

Comment on lines +127 to +137
missing_names = sorted(
    name
    for name in expected_transformer_sha256
    if name not in actual_transformer_sha256
)
mismatched_names = sorted(
    name
    for name, expected_sha256 in expected_transformer_sha256.items()
    if name in actual_transformer_sha256
    and actual_transformer_sha256[name] != expected_sha256
)

Severity: medium

The current implementation iterates over expected_transformer_sha256 twice to find missing and mismatched tensor names. This can be optimized by using a single loop, which would be more efficient and arguably more readable.

        missing_names = []
        mismatched_names = []
        for name, expected_sha256 in expected_transformer_sha256.items():
            actual_sha256_val = actual_transformer_sha256.get(name)
            if actual_sha256_val is None:
                missing_names.append(name)
            elif actual_sha256_val != expected_sha256:
                mismatched_names.append(name)
        missing_names.sort()
        mismatched_names.sort()

@ping1jing2 ping1jing2 marked this pull request as draft March 30, 2026 05:43
@ping1jing2
Collaborator

I've converted this PR to a draft since it depends on another PR; please click "Ready for review" once everything is ready.

