[diffusion] Add update_weights_from_tensor checker #21106
MikukuOvO wants to merge 8 commits into sgl-project:main
Conversation
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a critical verification mechanism for diffusion models, specifically targeting the update_weights_from_tensor workflow.
Code Review
This pull request introduces a valuable feature for verifying tensor updates in diffusion models by adding a weight checker. The implementation is comprehensive, including new API endpoints, scheduler and worker integration, and the core checker logic with support for tensor parallelism. The addition of both unit and end-to-end tests is commendable and ensures the feature's robustness. My review focuses on improving maintainability and readability by addressing code duplication and suggesting minor optimizations. Overall, this is a solid contribution.
```python
try:
    response = await async_scheduler_client.forward(req)
except Exception as e:
    return ORJSONResponse(
        {"success": False, "message": str(e)},
        status_code=500,
    )

result = response.output
success = result.get("success", False)
message = result.get("message", "Unknown status")
return ORJSONResponse(
    {"success": success, "message": message},
    status_code=200 if success else 400,
)
```
There is significant code duplication in handling scheduler requests and responses. The logic within this try...except block and the subsequent response parsing is nearly identical in update_weights_from_disk, update_weights_from_tensor, and update_weights_from_tensor_checker. To improve maintainability and reduce redundancy, consider extracting this common logic into a helper function.
For example, you could create a helper like this:
```python
async def _forward_and_respond(req: Any):
    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse(
            {"success": False, "message": str(e)},
            status_code=500,
        )
    result = response.output
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return ORJSONResponse(
        {"success": success, "message": message},
        status_code=200 if success else 400,
    )
```

Each endpoint could then call this helper after creating the specific request object, simplifying the endpoint logic.
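To make the suggested pattern concrete, here is a framework-free sketch of the shared forward-and-respond logic. `FakeSchedulerClient` and `SimpleResponse` are hypothetical stand-ins for the real `async_scheduler_client` and `ORJSONResponse`, used only so the example is self-contained:

```python
import asyncio
from dataclasses import dataclass
from typing import Any


@dataclass
class SimpleResponse:
    # Stand-in for ORJSONResponse: just a body dict and an HTTP status code.
    body: dict
    status_code: int


class FakeSchedulerClient:
    # Stand-in for async_scheduler_client; either returns a canned output
    # or raises a canned error, mimicking the two paths the helper handles.
    def __init__(self, output=None, error=None):
        self._output = output
        self._error = error

    async def forward(self, req):
        if self._error is not None:
            raise self._error

        @dataclass
        class _Resp:
            output: dict

        return _Resp(output=self._output or {})


async def forward_and_respond(client, req: Any) -> SimpleResponse:
    # Shared logic: forward to the scheduler, map failures to HTTP-style codes.
    try:
        response = await client.forward(req)
    except Exception as e:
        return SimpleResponse({"success": False, "message": str(e)}, 500)
    result = response.output
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return SimpleResponse(
        {"success": success, "message": message}, 200 if success else 400
    )


# Each endpoint then reduces to building its request and delegating:
ok = asyncio.run(
    forward_and_respond(
        FakeSchedulerClient({"success": True, "message": "updated"}), req=None
    )
)
bad = asyncio.run(
    forward_and_respond(FakeSchedulerClient(error=RuntimeError("scheduler down")), req=None)
)
print(ok.status_code, bad.status_code)  # 200 500
```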
```python
converted_metadata: list[FlattenedTensorMetadata] = []
for meta in metadata:
    converted_meta = FlattenedTensorMetadata(
        name=meta.name,
        shape=torch.Size(meta.shape),
        dtype=self._normalize_torch_dtype(meta.dtype),
        start_idx=int(meta.start_idx),
        end_idx=int(meta.end_idx),
        numel=int(meta.numel),
    )
    converted_metadata.append(converted_meta)
```
For improved readability and conciseness, this loop for building converted_metadata can be refactored into a list comprehension. This is a more idiomatic Python approach for transforming lists.
Suggested change:

```python
converted_metadata = [
    FlattenedTensorMetadata(
        name=meta.name,
        shape=torch.Size(meta.shape),
        dtype=self._normalize_torch_dtype(meta.dtype),
        start_idx=int(meta.start_idx),
        end_idx=int(meta.end_idx),
        numel=int(meta.numel),
    )
    for meta in metadata
]
```
```python
missing_names = sorted(
    name
    for name in expected_transformer_sha256
    if name not in actual_transformer_sha256
)
mismatched_names = sorted(
    name
    for name, expected_sha256 in expected_transformer_sha256.items()
    if name in actual_transformer_sha256
    and actual_transformer_sha256[name] != expected_sha256
)
```
The current implementation iterates over expected_transformer_sha256 twice to find missing and mismatched tensor names. This can be optimized by using a single loop, which would be more efficient and arguably more readable.
```python
missing_names = []
mismatched_names = []
for name, expected_sha256 in expected_transformer_sha256.items():
    actual_sha256_val = actual_transformer_sha256.get(name)
    if actual_sha256_val is None:
        missing_names.append(name)
    elif actual_sha256_val != expected_sha256:
        mismatched_names.append(name)
missing_names.sort()
mismatched_names.sort()
```
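For illustration, here is a self-contained run of the single-pass scan on toy data; the tensor names and digest strings below are fabricated placeholders, not real sha256 values:

```python
# Toy digests: one matching entry, one mismatched entry, one missing entry.
expected_transformer_sha256 = {
    "blocks.0.attn.qkv": "aaa",
    "blocks.0.mlp.fc1": "bbb",
    "blocks.1.attn.qkv": "ccc",
}
actual_transformer_sha256 = {
    "blocks.0.attn.qkv": "aaa",  # matches
    "blocks.0.mlp.fc1": "xxx",   # mismatched digest
    # "blocks.1.attn.qkv" is absent -> missing
}

missing_names = []
mismatched_names = []
for name, expected_sha256 in expected_transformer_sha256.items():
    actual_sha256_val = actual_transformer_sha256.get(name)
    if actual_sha256_val is None:
        missing_names.append(name)
    elif actual_sha256_val != expected_sha256:
        mismatched_names.append(name)
missing_names.sort()
mismatched_names.sort()

print(missing_names)     # ['blocks.1.attn.qkv']
print(mismatched_names)  # ['blocks.0.mlp.fc1']
```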
I converted this PR to draft as it depends on another PR.
Motivation
Depends on #20464.
This PR adds a diffusion-side checker for the update_weights_from_tensor workflow.
The goal is to verify that DiT (transformer) tensors are updated correctly after client-to-server tensor transfer, and to catch cases where tensor contents or mapping drift during the update path.
Modifications
- `UpdateWeightFromTensorCheckerReqInput` / `UpdateWeightFromTensorCheckerReqOutput` request and response types
- `POST /update_weights_from_tensor_checker` endpoint
- `UpdateWeightFromTensorChecker` in diffusion runtime utils for the `transformer` module, wiring `update_weights_from_tensor -> checker`
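The checker's core idea can be sketched without the SGLang runtime: hash each named weight's raw bytes with sha256 on both sides and compare digests per name. This is a minimal stand-in using plain byte strings for weights; the real checker would presumably hash the tensor's underlying bytes, which is an assumption about this PR's internals:

```python
import hashlib


def sha256_of(buf: bytes) -> str:
    # Placeholder for hashing a tensor's raw bytes; here `buf` is just a
    # byte string standing in for a transformer weight.
    return hashlib.sha256(buf).hexdigest()


# "Client-side" expected digests vs. "server-side" digests after the update.
expected = {
    "transformer.blocks.0.w": sha256_of(b"\x01\x02"),
    "transformer.blocks.1.w": sha256_of(b"\x03\x04"),
}
actual = {
    "transformer.blocks.0.w": sha256_of(b"\x01\x02"),  # updated correctly
    "transformer.blocks.1.w": sha256_of(b"\x03\x05"),  # content drifted
}

missing = sorted(n for n in expected if n not in actual)
mismatched = sorted(n for n, d in expected.items() if actual.get(n) not in (None, d))
success = not missing and not mismatched
print(success, missing, mismatched)
```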
This PR does not change model forward behavior or inference math.
Tests run:
- `python test/test_update_weight_from_tensor_checker.py`
- `python test/test_update_weights_from_tensor_checker_e2e.py`
This PR adds a debug/verification path only.
No speed benchmarking was performed.
Checklist