fix: cuda graph issue while running longcat_flash #14007
zhyncs merged 8 commits into sgl-project:main
Conversation
Code Review
This pull request correctly fixes a CUDA graph issue in longcat_flash by ensuring the qkv_latent_func is passed during the initialization of LayerCommunicator. This prevents an AssertionError that occurred because attention inputs were not being set. The fix is straightforward, consistent with its usage in other models, and is well-supported by the provided test results. The changes look good to me.
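The failure mode described above can be illustrated with a minimal sketch. This is not the actual sglang code: the class and method names mirror the ones mentioned in the review, but the bodies are invented to show how skipping `set_attn_inputs()` when `qkv_latent_func` is `None` leads `fetch_qkv_latent()` to assert later.

```python
# Illustrative sketch only -- hypothetical bodies, not sglang's implementation.
class LayerCommunicator:
    def __init__(self, qkv_latent_func=None):
        # Before the fix, callers could omit qkv_latent_func.
        self.qkv_latent_func = qkv_latent_func
        self._attn_inputs = None

    def set_attn_inputs(self, hidden_states):
        # Silently skipped when the callback is None -- the bug trigger.
        if self.qkv_latent_func is None:
            return
        self._attn_inputs = self.qkv_latent_func(hidden_states)

    def fetch_qkv_latent(self):
        # Asserts on state that the skipped call above never produced.
        assert self._attn_inputs is not None, "attention inputs were not set"
        return self._attn_inputs

# The fix amounts to always passing qkv_latent_func at initialization:
comm = LayerCommunicator(qkv_latent_func=lambda h: ("latent", h))
comm.set_attn_inputs([1.0, 2.0])
print(comm.fetch_qkv_latent())
```

With the callback wired in at construction time, the later fetch succeeds; with `qkv_latent_func=None`, the fetch raises the `AssertionError` seen in the report.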
/tag-and-rerun-ci
Motivation

Currently, when we run `longcat_flash`, it errors out with an `AssertionError`. This is because in #10568, `set_attn_inputs()` is skipped when `qkv_latent_func == None`, and later `fetch_qkv_latent()` asserts on this and fails.

Modifications
Make sure `qkv_latent_func` is set.

Accuracy Tests
After the change, the following launch command runs and outputs normal tokens:

```shell
python3 -m sglang.launch_server \
  --trust-remote-code \
  --model $MODEL_PATH \
  --tp 8 \
  --ep-size 8 \
  --skip-server-warmup \
  --cuda-graph-bs 1 2 3 4 5 6 7 8 \
  --host 0.0.0.0 \
  --port 8080
```

gsm8k:
Benchmarking and Profiling
gsm8k:
Same as when I revert to before it was broken:
Checklist