
Optimize triton_mrope with torch compile#12112

Merged
hnyls2002 merged 5 commits into sgl-project:main from antgroup:optimize_triton_mrope
Oct 27, 2025

Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Oct 25, 2025

Motivation

In the recent Triton mrope PR (#11722), several Torch ops introduced overhead. Despite kernel-level speedups of 30%-40%, end-to-end performance regressed.

This PR optimizes triton mrope by applying torch compile. With this enhancement, end-to-end performance on the triton mrope path now exceeds the legacy version.
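As a rough illustration of the idea (not the actual SGLang code; `mrope_preamble` and `cos_sin_cache` are invented names), the small host-side ATen preamble feeding the Triton kernel is wrapped in torch.compile so its many tiny ops are fused instead of launched one by one:

```python
import torch

# Hedged sketch only: compile the small indexing/elementwise preamble that
# feeds the Triton mrope kernel. backend="eager" keeps the sketch runnable
# without a C++ toolchain; the real path would use the default inductor
# backend.

@torch.compile(dynamic=True, backend="eager")
def mrope_preamble(positions: torch.Tensor, cos_sin_cache: torch.Tensor):
    # Gather the cached cos/sin rows for each position and split them --
    # exactly the kind of small per-op launch overhead torch.compile removes.
    cos_sin = cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    return cos.contiguous(), sin.contiguous()

# Toy cache: 4 positions, rotary dim 4 (first half cos, second half sin).
cache = torch.arange(16.0).reshape(4, 4)
cos, sin = mrope_preamble(torch.tensor([0, 2]), cache)
```

`dynamic=True` avoids recompiling for every new sequence length, which matters for serving workloads with varying prompt sizes.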

The VLM online latency test:

Compared to the version before #11722:
TTFT drops from 156.52ms to 131.91ms, a 16% speedup.
E2E latency drops from 8434.92ms to 8272.37ms, a 2% speedup.

Compared to main:
TTFT drops from 182.89ms to 131.91ms, a 28% speedup.
E2E latency drops from 8903.00ms to 8272.37ms, a 7% speedup.

python3 -m unittest test_bench_serving.TestBenchServing.test_vlm_online_latency

Main: Triton mrope kernel
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1
Max request concurrency:                 not set
Successful requests:                     250
Benchmark duration (s):                  247.85
Total input tokens:                      100908
Total input text tokens:                 21831
Total input vision tokens:               79077
Total generated tokens:                  512000
Total generated tokens (retokenized):    191169
Request throughput (req/s):              1.01
Input token throughput (tok/s):          407.14
Output token throughput (tok/s):         2065.78
Total token throughput (tok/s):          2472.92
Concurrency:                             8.98
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8903.00
Median E2E Latency (ms):                 8866.82
---------------Time to First Token----------------
Mean TTFT (ms):                          182.89
Median TTFT (ms):                        151.15
P99 TTFT (ms):                           689.05
---------------Inter-Token Latency----------------
Mean ITL (ms):                           4.26
Median ITL (ms):                         3.79
P95 ITL (ms):                            6.45
P99 ITL (ms):                            14.88
Max ITL (ms):                            426.98
==================================================
/usr/lib/python3.10/subprocess.py:1072: ResourceWarning: subprocess 336084 is still running
  _warn("subprocess %s is still running" % self.pid,
ResourceWarning: Enable tracemalloc to get the object allocation traceback
.
----------------------------------------------------------------------
Ran 1 test in 417.063s

OK

Before 11722: Native mrope path
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1
Max request concurrency:                 not set
Successful requests:                     250
Benchmark duration (s):                  247.37
Total input tokens:                      100908
Total input text tokens:                 21831
Total input vision tokens:               79077
Total generated tokens:                  512000
Total generated tokens (retokenized):    191477
Request throughput (req/s):              1.01
Input token throughput (tok/s):          407.92
Output token throughput (tok/s):         2069.74
Total token throughput (tok/s):          2477.66
Concurrency:                             8.52
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8434.92
Median E2E Latency (ms):                 8430.82
---------------Time to First Token----------------
Mean TTFT (ms):                          156.52
Median TTFT (ms):                        135.53
P99 TTFT (ms):                           524.26
---------------Inter-Token Latency----------------
Mean ITL (ms):                           4.04
Median ITL (ms):                         3.64
P95 ITL (ms):                            4.02
P99 ITL (ms):                            9.15
Max ITL (ms):                            302.20
==================================================
/usr/lib/python3.10/subprocess.py:1072: ResourceWarning: subprocess 337645 is still running
  _warn("subprocess %s is still running" % self.pid,
ResourceWarning: Enable tracemalloc to get the object allocation traceback
.
----------------------------------------------------------------------
Ran 1 test in 350.831s

OK

This PR: Triton mrope with torch compile
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1
Max request concurrency:                 not set
Successful requests:                     250
Benchmark duration (s):                  247.31
Total input tokens:                      100908
Total input text tokens:                 21831
Total input vision tokens:               79077
Total generated tokens:                  512000
Total generated tokens (retokenized):    188702
Request throughput (req/s):              1.01
Input token throughput (tok/s):          408.02
Output token throughput (tok/s):         2070.26
Total token throughput (tok/s):          2478.28
Concurrency:                             8.36
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8272.37
Median E2E Latency (ms):                 8262.43
---------------Time to First Token----------------
Mean TTFT (ms):                          131.91
Median TTFT (ms):                        111.33
P99 TTFT (ms):                           404.65
---------------Inter-Token Latency----------------
Mean ITL (ms):                           3.98
Median ITL (ms):                         3.61
P95 ITL (ms):                            4.13
P99 ITL (ms):                            8.67
Max ITL (ms):                            258.47
==================================================
/usr/lib/python3.10/subprocess.py:1072: ResourceWarning: subprocess 352487 is still running
  _warn("subprocess %s is still running" % self.pid,
ResourceWarning: Enable tracemalloc to get the object allocation traceback
.
----------------------------------------------------------------------
Ran 1 test in 336.764s

OK

Modifications

Accuracy Tests


MMMU test, no accuracy drop.
➜  sglang git:(main) ✗ python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
Preparing samples...
Loading datasets for 30 subjects...
Loading datasets: 100%|██████████| 30/30 [00:06<00:00,  4.55it/s]
Saving images to: /root/.cache/mmmu/images
Processing samples...
Processing samples: 100%|██████████| 900/900 [00:00<00:00, 211335.44it/s]
Skipping 0 samples with large images, 0.0% of dataset
Samples have been prepared
100%|██████████| 900/900 [03:25<00:00,  4.38it/s]
Benchmark time: 205.43400778598152
answers saved to: ./answer_sglang.json
Evaluating...
answers saved to: ./answer_sglang.json
{'Accounting': {'acc': 0.433, 'num': 30},
 'Agriculture': {'acc': 0.533, 'num': 30},
 'Architecture_and_Engineering': {'acc': 0.467, 'num': 30},
 'Art': {'acc': 0.667, 'num': 30},
 'Art_Theory': {'acc': 0.833, 'num': 30},
 'Basic_Medical_Science': {'acc': 0.6, 'num': 30},
 'Biology': {'acc': 0.367, 'num': 30},
 'Chemistry': {'acc': 0.367, 'num': 30},
 'Clinical_Medicine': {'acc': 0.633, 'num': 30},
 'Computer_Science': {'acc': 0.433, 'num': 30},
 'Design': {'acc': 0.7, 'num': 30},
 'Diagnostics_and_Laboratory_Medicine': {'acc': 0.467, 'num': 30},
 'Economics': {'acc': 0.467, 'num': 30},
 'Electronics': {'acc': 0.267, 'num': 30},
 'Energy_and_Power': {'acc': 0.267, 'num': 30},
 'Finance': {'acc': 0.367, 'num': 30},
 'Geography': {'acc': 0.4, 'num': 30},
 'History': {'acc': 0.667, 'num': 30},
 'Literature': {'acc': 0.767, 'num': 30},
 'Manage': {'acc': 0.333, 'num': 30},
 'Marketing': {'acc': 0.5, 'num': 30},
 'Materials': {'acc': 0.433, 'num': 30},
 'Math': {'acc': 0.567, 'num': 30},
 'Mechanical_Engineering': {'acc': 0.5, 'num': 30},
 'Music': {'acc': 0.467, 'num': 30},
 'Overall': {'acc': 0.513, 'num': 900},
 'Overall-Art and Design': {'acc': 0.667, 'num': 120},
 'Overall-Business': {'acc': 0.42, 'num': 150},
 'Overall-Health and Medicine': {'acc': 0.587, 'num': 150},
 'Overall-Humanities and Social Science': {'acc': 0.667, 'num': 120},
 'Overall-Science': {'acc': 0.427, 'num': 150},
 'Overall-Tech and Engineering': {'acc': 0.414, 'num': 210},
 'Pharmacy': {'acc': 0.6, 'num': 30},
 'Physics': {'acc': 0.433, 'num': 30},
 'Psychology': {'acc': 0.7, 'num': 30},
 'Public_Health': {'acc': 0.633, 'num': 30},
 'Sociology': {'acc': 0.533, 'num': 30}}
eval out saved to ./val_sglang.json
Overall accuracy: 0.513

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance of the Triton-based rotary embedding implementation by integrating torch.compile. The primary goal is to mitigate performance bottlenecks caused by certain PyTorch operations within the _forward_triton method. By compiling this critical function, the change significantly improves the Time to First Token (TTFT) and overall end-to-end latency, leading to a more optimized and faster execution path for the system.

Highlights

  • Performance Optimization: Applied torch.compile with dynamic compilation to the _forward_triton method within the rotary embedding layer to reduce overhead from native PyTorch operations.
  • Significant Speedup: Achieved a 16% speedup in Time to First Token (TTFT), reducing it from 156.52ms to 131.41ms, making the Triton mrope path more efficient than the legacy version.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization by applying torch.compile to the _forward_triton method in MRotaryEmbedding. The change is well-motivated, and the provided benchmarks clearly demonstrate a significant speedup, particularly in Time to First Token (TTFT). This is a valuable improvement.

My review includes one suggestion for a minor refactoring within the _forward_triton method to improve code clarity and maintainability by removing some now-unreachable code. This will make the function cleaner, especially as it is now a compilation target.

@yuan-luo
Collaborator Author

yuan-luo commented Oct 25, 2025

CC: @mickqian @JustinTong0323

@yuan-luo
Collaborator Author

Some VL tests failed. Investigating; setting WIP.


======================================================================
ERROR: test_video_chat_completion (__main__.TestQwen25VLServer)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/srt/utils/common.py", line 2450, in retry
    return fn()
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/test/test_utils.py", line 1628, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
AssertionError: video_response: ampire, should contain 'iPod' or 'device'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/test/test_utils.py", line 1627, in _callTestMethod
    retry(
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/srt/utils/common.py", line 2455, in retry
    raise Exception(f"retry() exceed maximum number of retries.")
Exception: retry() exceed maximum number of retries.

======================================================================
ERROR: test_video_images_chat_completion (__main__.TestQwen25VLServer)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/srt/utils/common.py", line 2450, in retry
    return fn()
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/test/test_utils.py", line 1628, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
AssertionError: 
        ====================== video_images response =====================
          

        ===========================================================
        should contain 'iPod' or 'device' or 'microphone'
        

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/test/test_utils.py", line 1627, in _callTestMethod
    retry(
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/srt/utils/common.py", line 2455, in retry
    raise Exception(f"retry() exceed maximum number of retries.")
Exception: retry() exceed maximum number of retries.

----------------------------------------------------------------------
Ran 63 tests in 809.449s

FAILED (errors=5, skipped=6)
------------------------------
Video images response:
The video presents a close-up, static shot of a man, identifiable as Steve Jobs, holding a new electronic device. The scene is set against a dark, indistinct background, which serves to highlight the subject and the product. The man is wearing a black, long-sleeved shirt and round, thin-framed glasses. His face, from the nose down, is visible, and he appears to be speaking, with his mouth slightly open. He holds a white, rectangular device in his right hand, presenting it to the camera. The device is a first-generation iPod Shuffle, characterized by its slim, minimalist design, a small black screen at the top, and a circular control pad below it. The lighting is focused on the man and the device, creating a dramatic and professional presentation atmosphere. The overall composition and the man's demeanor suggest a product launch event. The video is a still image, with no movement or change in the scene across the provided frames.
------------------------------
.
.
End (0/9):
filename='/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/test/srt/test_vision_openai_server_a.py', elapsed=816, estimated_time=918
.
.

Traceback (most recent call last):
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/test/srt/run_suite.py", line 462, in <module>
    exit_code = run_unittest_files(files, args.timeout_per_file)
  File "/public_sglang_ci/runner-l1c-gpu-23/_work/sglang/sglang/python/sglang/test/test_utils.py", line 762, in run_unittest_files
    ret_code == 0
AssertionError: expected return code 0, but test_vision_openai_server_a.py returned 1

@yuan-luo yuan-luo changed the title Optimize triton_mrope with torch compile [WIP] Optimize triton_mrope with torch compile Oct 26, 2025
@yuan-luo yuan-luo marked this pull request as draft October 26, 2025 02:20
@yuan-luo
Collaborator Author

yuan-luo commented Oct 26, 2025

After analysis, these unit tests failed with the same issue:

python3 -m unittest test_bench_serving.TestBenchServing.test_vlm_online_latency
python3 -m unittest test_vision_openai_server_a.py  cost 816s
python3 -m unittest test_vlm_input_format.py  cost   170s
python3 -m unittest test_skip_tokenizer_init.py  cost 91s
python3 -m unittest test_w4a8_deepseek_v3.py cost 1011s

The root cause of the errors introduced by adding @torch.compile to _forward_triton is the following:
The Triton kernel in triton_mrope writes back to q_ptr/k_ptr in place (even though there is a contiguous() call outside, it still has "input = output" semantics). When _forward_triton is wrapped with @torch.compile, Dynamo/Inductor cannot model the Triton kernel's side effects (it cannot "see" tl.store) and tends to treat the call as a pure black box that does not mutate its inputs. This leads to:

  1. Unsound memory planning (buffer reuse/overlap);
  2. Incorrect operator reordering (moving the consumers of q/k before the Triton call).

With CUDA Graphs enabled, the captured graph bakes in the pointers/schedule from the first run; on replay, if the surrounding tensors/storage have changed, the result is seemingly random corruption.
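The CUDA Graph hazard can be illustrated with a pure-Python toy (no CUDA involved; this only models the "replay targets the first-run pointers" behavior):

```python
# Toy model of graph capture/replay: the "graph" records references to the
# exact buffers used during capture, and replay writes to those same buffers.
# If the caller later swaps in fresh storage, replay still targets the old
# one -- the seemingly random corruption described above.

class CapturedGraph:
    def __init__(self):
        self.ops = []

    def capture(self, buf, values):
        # Record the buffer object itself, analogous to baking device
        # pointers into a captured CUDA graph.
        self.ops.append((buf, list(values)))

    def replay(self):
        for buf, values in self.ops:
            buf[:] = values  # writes always land in the captured storage

graph = CapturedGraph()
first_run = [0, 0, 0]
graph.capture(first_run, [1, 2, 3])
graph.replay()            # first_run becomes [1, 2, 3]

second_run = [0, 0, 0]    # caller allocates new storage for the next step...
graph.replay()            # ...but replay still writes into first_run
```

`second_run` stays all zeros after replay: the writes went to the stale buffer, which is the failure mode the graph break guards against.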

The fix is to insert an explicit graph break before and after the call (effective during tracing):

torch._dynamo.graph_break()
q, k = triton_mrope(...)
torch._dynamo.graph_break()

The trade-off is that the Triton call site itself won't be fused by the compiler, but the rest of the ATen path can still be optimized.
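A runnable sketch of this pattern (with invented names: `opaque_inplace_rotate` stands in for `triton_mrope`, and backend="eager" is used only so the sketch runs without inductor):

```python
import torch
import torch._dynamo

def opaque_inplace_rotate(q: torch.Tensor, k: torch.Tensor) -> None:
    # Stand-in for the Triton kernel: mutates q/k in place. The real kernel's
    # tl.store side effects are invisible to Dynamo's functional model.
    q.mul_(2.0)
    k.mul_(2.0)

@torch.compile(dynamic=True, backend="eager")
def forward_mrope(q: torch.Tensor, k: torch.Tensor):
    # ...compiled ATen preamble would run here...
    torch._dynamo.graph_break()   # stop tracing before the mutating call
    opaque_inplace_rotate(q, k)
    torch._dynamo.graph_break()   # resume tracing after it
    # ...compiled ATen epilogue now runs after the mutation, never before...
    return q.sum(), k.sum()

s_q, s_k = forward_mrope(torch.ones(4), torch.ones(4))
```

The explicit breaks pin the ordering: everything before the kernel is one compiled graph, everything after is another, and the opaque in-place call runs eagerly in between.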
With this fix, the unit tests pass, and performance stays almost the same as the first version of the fix:
TTFT speedup: 16%
E2E speedup: 2%

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1
Max request concurrency:                 not set
Successful requests:                     250
Benchmark duration (s):                  247.31
Total input tokens:                      100908
Total input text tokens:                 21831
Total input vision tokens:               79077
Total generated tokens:                  512000
Total generated tokens (retokenized):    188702
Request throughput (req/s):              1.01
Input token throughput (tok/s):          408.02
Output token throughput (tok/s):         2070.26
Total token throughput (tok/s):          2478.28
Concurrency:                             8.36
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8272.37
Median E2E Latency (ms):                 8262.43
---------------Time to First Token----------------
Mean TTFT (ms):                          131.91
Median TTFT (ms):                        111.33
P99 TTFT (ms):                           404.65
---------------Inter-Token Latency----------------
Mean ITL (ms):                           3.98
Median ITL (ms):                         3.61
P95 ITL (ms):                            4.13
P99 ITL (ms):                            8.67
Max ITL (ms):                            258.47
==================================================

@yuan-luo yuan-luo force-pushed the optimize_triton_mrope branch from 807015f to dd164c6 Compare October 26, 2025 06:09
@yuan-luo yuan-luo marked this pull request as ready for review October 26, 2025 06:10
@yuan-luo yuan-luo changed the title [WIP] Optimize triton_mrope with torch compile Optimize triton_mrope with torch compile Oct 26, 2025
@yuan-luo
Collaborator Author

yuan-luo commented Oct 26, 2025

The test script (the prompt asks, in Chinese, "What does the sign in the video say?"):

$ cat bench_local_video.sh
while true; do time curl 'http://127.0.0.1:30000/v1/chat/completions' --header 'Content-Type: application/json' --data '{
        "model": "auto",
        "messages": [
            {
                "role": "user",
                "content": [
                  {"type": "video_url", "video_url": {"url": "/tmp/video_test.mp4"}},
                  {"type": "text", "text": "视频里的招牌写的什么"}
                ]
            }
        ],
                                                  
        "temperature":0.0,
        "max_tokens":1000,
        "stream": false,
        "chat_template_kwargs": {"enable_thinking": false}
    }'; done

With torch compile alone, without the graph break, the result is incorrect (the model loops endlessly on a single token):

{"id":"5d8de39855514e3ea67c878cf47df44e","object":"chat.completion","created":1761456840,"model":"auto","choices":[{"index":0,"message":{"role":"assistant","content":" Республик Республик Республик ... (the single token repeated until max_tokens) ... Республик","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"length","matched_stop":null}],"usage":{"prompt_tokens":20906,"total_tokens":21906,"completion_tokens":1000,"prompt_tokens_details":{"cached_tokens":20905},"reasoning_tokens":0},"metadata":{"weight_version":"default"}}
real    0m15.103s
user    0m0.002s
sys     0m0.003s

With torch compile plus the graph break, the result is correct (the model reads the sign as "小鞋匠洗鞋", roughly "Little Cobbler Shoe Cleaning"):

{"id":"04fd3d0397a6494e8515e1d4251f39c2","object":"chat.completion","created":1761457204,"model":"auto","choices":[{"index":0,"message":{"role":"assistant","content":"视频中的招牌上写着“小鞋匠洗鞋”。招牌上还有一些其他文字和图案,但主要的招牌内容是“小鞋匠洗鞋”。","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":20906,"total_tokens":20940,"completion_tokens":34,"prompt_tokens_details":{"cached_tokens":20905},"reasoning_tokens":0},"metadata":{"weight_version":"default"}}
real    0m3.438s
user    0m0.001s
sys     0m0.003s
{"id":"867258d1f27b413e9104040ab10ef5d3","object":"chat.completion","created":1761457207,"model":"auto","choices":[{"index":0,"message":{"role":"assistant","content":"视频中的招牌上写着“小鞋匠洗鞋”。招牌上还有一些其他文字和图案,但主要的招牌内容是“小鞋匠洗鞋”。","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":20906,"total_tokens":20940,"completion_tokens":34,"prompt_tokens_details":{"cached_tokens":20905},"reasoning_tokens":0},"metadata":{"weight_version":"default"}}
real    0m3.389s
user    0m0.000s
sys     0m0.004s

Collaborator

@BBuf BBuf left a comment


LGTM.

@yuan-luo
Collaborator Author

yuan-luo commented Oct 27, 2025

The VLLM Dependency Test failed after #12117 was merged.
Investigating the combined effect of this PR and #12117.
[Updated] This error also exists on the main branch and is not related to #12112.

Traceback (most recent call last):
  File "/public_sglang_ci/runner-l3b-gpu-0/_work/sglang/sglang/python/sglang/srt/utils/common.py", line 2447, in retry
    return fn()
  File "/public_sglang_ci/runner-l3b-gpu-0/_work/sglang/sglang/python/sglang/test/test_utils.py", line 1628, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/public_sglang_ci/runner-l3b-gpu-0/_work/sglang/sglang/test/srt/quant/test_awq.py", line 76, in test_mmlu
    self.assertGreater(metrics["score"], 0.88)
  File "/usr/lib/python3.10/unittest/case.py", line 1244, in assertGreater
    self.fail(self._formatMessage(msg, standardMsg))
  File "/usr/lib/python3.10/unittest/case.py", line 675, in fail
    raise self.failureException(msg)
AssertionError: 0.875 not greater than 0.88
E
Writing report to /tmp/mmlu_QuantTrio_Qwen3-VL-30B-A3B-Instruct-AWQ.html
{'other': 0.875, 'other:std': 0.33071891388307384, 'score:std': 0.33071891388307384, 'stem': 0.9090909090909091, 'stem:std': 0.28747978728803447, 'humanities': 0.8695652173913043, 'humanities:std': 0.33678116053977536, 'social_sciences': 0.8571428571428571, 'social_sciences:std': 0.3499271061118826, 'score': 0.875}
Writing results to /tmp/mmlu_QuantTrio_Qwen3-VL-30B-A3B-Instruct-AWQ.json
Total latency: 16.924 s
Score: 0.875

@yuan-luo
Collaborator Author

yuan-luo commented Oct 27, 2025

I verified that this issue still exists on the main branch.

[root  /root/luoyuan.luo/workspace/sglang_dev] Mon Oct 27 19:36:44
$python ./test/srt/quant/test_awq.py
/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
command=python3 -m sglang.launch_server --model-path hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4 --trust-remote-code --device cuda --host 127.0.0.1 --port 21000
/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
[2025-10-27 19:36:57] INFO awq.py:275: The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-10-27 19:36:57] WARNING server_args.py:1104: Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-10-27 19:36:57] INFO trace.py:48: opentelemetry package is not installed, tracing disabled

......

Writing report to /tmp/mmlu_QuantTrio_Qwen3-VL-30B-A3B-Instruct-AWQ.html
{'other': np.float64(0.875), 'other:std': np.float64(0.33071891388307384), 'score:std': np.float64(0.3476343040826092), 'stem': np.float64(0.9090909090909091), 'stem:std': np.float64(0.28747978728803447), 'humanities': np.float64(0.8695652173913043), 'humanities:std': np.float64(0.33678116053977536), 'social_sciences': np.float64(0.7857142857142857), 'social_sciences:std': np.float64(0.41032590332414487), 'score': np.float64(0.859375)}
Writing results to /tmp/mmlu_QuantTrio_Qwen3-VL-30B-A3B-Instruct-AWQ.json
Total latency: 9.664 s
Score: 0.859
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/utils/common.py", line 2447, in retry
    return fn()
  File "/opt/conda/lib/python3.10/site-packages/sglang/test/test_utils.py", line 1628, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
  File "/opt/conda/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/root/luoyuan.luo/workspace/sglang_dev/./test/srt/quant/test_awq.py", line 76, in test_mmlu
    self.assertGreater(metrics["score"], 0.88)
  File "/opt/conda/lib/python3.10/unittest/case.py", line 1244, in assertGreater
    self.fail(self._formatMessage(msg, standardMsg))
  File "/opt/conda/lib/python3.10/unittest/case.py", line 675, in fail
    raise self.failureException(msg)
AssertionError: np.float64(0.859375) not greater than 0.88
E
======================================================================
ERROR: test_mmlu (__main__.TestAWQMarlinBfloat16)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/utils/common.py", line 2447, in retry
    return fn()
  File "/opt/conda/lib/python3.10/site-packages/sglang/test/test_utils.py", line 1628, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
AssertionError: np.float64(0.859375) not greater than 0.88

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/test/test_utils.py", line 1627, in _callTestMethod
    retry(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/utils/common.py", line 2452, in retry
    raise Exception(f"retry() exceed maximum number of retries.")
Exception: retry() exceed maximum number of retries.

----------------------------------------------------------------------
Ran 2 tests in 169.780s

FAILED (errors=1)

@hnyls2002 hnyls2002 enabled auto-merge (squash) October 27, 2025 15:49
@hnyls2002 hnyls2002 disabled auto-merge October 27, 2025 15:49
@hnyls2002 hnyls2002 merged commit f389f01 into sgl-project:main Oct 27, 2025
32 of 70 checks passed
@yuan-luo yuan-luo deleted the optimize_triton_mrope branch November 2, 2025 12:12