[Ckpt Engine] feat: new sglang entrypoint support for update #12216
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the SGLang checkpoint engine by providing a user-friendly, integrated command-line interface for updating model weights. It abstracts away the complexities of distributed execution by automatically invoking torchrun.
Code Review
This pull request introduces a new entrypoint for updating model weights via the checkpoint engine, which is a useful addition for ease of use. The implementation is well-structured. I've identified a couple of areas for improvement: one regarding the robustness of command-line argument parsing before launching torchrun, and another related to a logic bug that causes an unnecessary delay in p2p update mode. My review includes specific code suggestions to address these points.
    if update_method:
        # sleep 2s to wait destroy process group
        time.sleep(2)
The time.sleep(2) is executed whenever update_method is "p2p" or "all". However, the comment "sleep 2s to wait destroy process group" suggests this delay is only necessary after the broadcast update method has run. This happens when update_method is "all", but not when it is "p2p". When update_method is "p2p", no broadcast occurs, so the sleep is unnecessary and introduces a performance penalty. The condition should be more specific to only sleep when update_method == "all". Additionally, the if update_method: check is redundant as it's always true within this block.
Suggested change:

    - if update_method:
    -     # sleep 2s to wait destroy process group
    -     time.sleep(2)
    + if update_method == "all":
    +     # sleep 2s to wait destroy process group
    +     time.sleep(2)
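A minimal sketch of the corrected control flow. The function and path names below are illustrative stand-ins, not the PR's actual API; the point is that the delay applies only on the "all" path, where a broadcast has just torn down its process group:

```python
import time


# Hypothetical stand-ins for the engine's two update paths.
def broadcast_update(log):
    log.append("broadcast")  # creates and destroys a process group internally


def p2p_update(log):
    log.append("p2p")  # no process group involved


def run_update(update_method, log):
    """Run the requested update paths; sleep only after a broadcast."""
    if update_method in ("broadcast", "all"):
        broadcast_update(log)
    if update_method == "all":
        # Only the broadcast path tears down a process group, so only
        # the "all" path needs to wait before starting the p2p phase.
        time.sleep(2)
    if update_method in ("p2p", "all"):
        p2p_update(log)
```

With this structure, a plain "p2p" update no longer pays the two-second penalty.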
    for i, arg in enumerate(args):
        if arg == "--inference-parallel-size" and i + 1 < len(args):
            try:
                inference_parallel_size = int(args[i + 1])
            except ValueError:
                pass
            break
        elif arg.startswith("--inference-parallel-size="):
            try:
                inference_parallel_size = int(arg.split("=", 1)[1])
            except ValueError:
                pass
            break
The manual parsing of --inference-parallel-size uses try-except pass, which silently ignores invalid values. If a user provides a non-integer value, the script proceeds with the default value of 8 for nproc-per-node, only to fail later during the more robust argument parsing in main(). This behavior can be confusing. It's better to fail early with a clear error. Removing the try-except blocks will allow the ValueError from int() to propagate and terminate the script, which is more robust and user-friendly.
Suggested change:

    - for i, arg in enumerate(args):
    -     if arg == "--inference-parallel-size" and i + 1 < len(args):
    -         try:
    -             inference_parallel_size = int(args[i + 1])
    -         except ValueError:
    -             pass
    -         break
    -     elif arg.startswith("--inference-parallel-size="):
    -         try:
    -             inference_parallel_size = int(arg.split("=", 1)[1])
    -         except ValueError:
    -             pass
    -         break
    + for i, arg in enumerate(args):
    +     if arg == "--inference-parallel-size" and i + 1 < len(args):
    +         inference_parallel_size = int(args[i + 1])
    +         break
    +     elif arg.startswith("--inference-parallel-size="):
    +         inference_parallel_size = int(arg.split("=", 1)[1])
    +         break
    # Build torchrun command
    cmd = ["torchrun", f"--nproc-per-node={inference_parallel_size}", __file__] + args

    print(f"Running: {' '.join(cmd)}", file=sys.stderr)

    # Execute torchrun with the original script
    try:
        result = subprocess.run(cmd, check=False)
        sys.exit(result.returncode)
    except FileNotFoundError:
        print(
            "Error: torchrun command not found. Please ensure PyTorch is installed.",
            file=sys.stderr,
        )
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nInterrupted by user", file=sys.stderr)
        sys.exit(130)
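A common pattern for supporting both a direct entrypoint and torchrun (sketched here as an assumption about the approach, not taken verbatim from the PR) is to check the environment variables that torchrun exports to each worker. When launched directly, the script re-executes itself under torchrun; when it is already a torchrun worker, it proceeds straight to its main logic:

```python
import os
import subprocess
import sys


def launched_by_torchrun():
    # torchrun (torch.distributed.run) sets RANK, WORLD_SIZE, and
    # LOCAL_RANK in each worker's environment.
    return "RANK" in os.environ and "WORLD_SIZE" in os.environ


def dispatch(args, nproc=8):
    """Re-exec under torchrun unless we are already a torchrun worker."""
    if launched_by_torchrun():
        return "worker"  # in the real script: fall through to main()
    cmd = ["torchrun", f"--nproc-per-node={nproc}", __file__] + list(args)
    result = subprocess.run(cmd, check=False)
    sys.exit(result.returncode)
```

This keeps a single file usable both ways, which matches the PR's goal of supporting the torchrun and sglang entrypoints at once.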
Is torchrun compulsory?
If we use ParameterServer in the checkpoint engine, torchrun is compulsory.
Motivation
Add a new entrypoint for ease of use, as suggested in #11755 (comment).
With this PR, both the torchrun and the sglang entrypoints are supported.
Usage
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist