
[AMD] Support Qwen3-Coder-Next on AMD platform #18355

Merged
HaiShaw merged 6 commits into sgl-project:main from yichiche:enable-qwen3-coder-next-hip on Feb 25, 2026

Conversation

@yichiche (Collaborator) commented Feb 6, 2026

Motivation

Enable the Qwen3-Coder-Next model on the AMD GPU platform. With this PR, both non-MTP (with fp8 KV cache) and MTP are supported on Qwen3-Coder-Next.

Modifications

  • aiter_backend.py:
    • Handle v_head_dim correctly for MLA and hybrid linear models. Previously, v_head_dim was retrieved directly from token_to_kv_pool.get_value_buffer(0), which fails for models where layer 0 may not be a full attention layer. Now properly handles MLA models (using model config), hybrid linear models (using get_v_head_dim()), and standard models.
    • Enable MTP with the Triton backend; aiter MTP for the non-MLA decode kernel will be supported in the future.
  • qwen3_next.py: Disable dual-stream on AMD platform
  • hybrid_linear_attn_backend.py: Make CuTe DSL GDN import optional and raise an explicit error only when CuTe DSL decode is enabled without required dependency.
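The v_head_dim handling in the first bullet can be sketched as follows. This is a hedged illustration, not the actual patch: the accessor names (`get_v_head_dim`, `get_value_buffer`) follow the PR description, but the exact guard used to detect an MLA model and the attribute layout on `model_runner` are assumptions.

```python
# Hypothetical sketch of the v_head_dim selection described above.
# Attribute names (model_config.attention_arch, token_to_kv_pool) are assumed.
def resolve_v_head_dim(model_runner):
    if getattr(model_runner.model_config, "attention_arch", None) == "MLA":
        # MLA models: layer 0 may not be a full attention layer, so read the
        # value head dimension from the model config instead of a KV buffer.
        return model_runner.model_config.v_head_dim
    pool = model_runner.token_to_kv_pool
    if hasattr(pool, "get_v_head_dim"):
        # Hybrid linear-attention models expose a dedicated accessor that
        # does not depend on layer 0 being a full attention layer.
        return pool.get_v_head_dim()
    # Standard models: fall back to the original buffer-shape lookup.
    return pool.get_value_buffer(0).shape[-1]
```

The key point is ordering: the buffer-shape lookup, which fails when layer 0 is a linear-attention layer, is only reached for standard models.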

Accuracy Tests

MODEL="/data/Qwen/Qwen3-Coder-Next/"
python3 -m sglang.launch_server \
 --model-path "${MODEL}" \
 --tensor-parallel-size 8 \
 --trust-remote-code \
 --chunked-prefill-size 131072 \
 --host 0.0.0.0 \
 --port 9000 \
 --log-requests \
 --disable-radix-cache \
 --mem-fraction-static 0.8 \
 --attention-backend aiter 

Accuracy: 0.944
Invalid: 0.000
Latency: 55.824 s
Output throughput: 3066.797 token/s

Benchmarking and Profiling

Env: MI355 * 8

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 1         
Successful requests:                     8         
Benchmark duration (s):                  23.50     
Total input tokens:                      560000    
Total input text tokens:                 560000    
Total generated tokens:                  1600      
Total generated tokens (retokenized):    1600      
Request throughput (req/s):              0.34      
Input token throughput (tok/s):          23834.71  
Output token throughput (tok/s):         68.10     
Peak output token throughput (tok/s):    95.00     
Peak concurrent requests:                2         
Total token throughput (tok/s):          23902.81  
Concurrency:                             1.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2934.06   
Median E2E Latency (ms):                 2929.72   
P90 E2E Latency (ms):                    2972.40   
P99 E2E Latency (ms):                    2974.58   
---------------Time to First Token----------------
Mean TTFT (ms):                          842.12    
Median TTFT (ms):                        838.34    
P99 TTFT (ms):                           882.41    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.51     
Median TPOT (ms):                        10.51     
P99 TPOT (ms):                           10.52     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           10.51     
Median ITL (ms):                         10.51     
P95 ITL (ms):                            10.58     
P99 ITL (ms):                            10.74     
Max ITL (ms):                            11.94     
==================================================

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist bot commented

Summary of Changes

Hello @yichiche, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces critical updates to enable and optimize the Qwen3-Coder-Next model's performance and compatibility on AMD platforms. The changes focus on improving the robustness of attention mechanism configurations by correctly handling model-specific parameters and adapting dual-stream behavior for AMD GPUs. Additionally, it enhances dependency management for optional performance optimizations, leading to a more stable and compatible experience for users deploying Qwen3-Coder-Next on AMD hardware.

Highlights

  • Enhanced v_head_dim Calculation: The logic for determining v_head_dim in aiter_backend.py has been refined to correctly handle MLA (Multi-Layer Attention) and hybrid linear models, preventing issues where layer 0 might not be a full attention layer.
  • AMD Platform Dual-Stream Disablement: Dual-stream functionality has been explicitly disabled for AMD platforms in qwen3_next.py to ensure compatibility and stability.
  • Optional CuTe DSL GDN Import: The import of CuTe DSL GDN in hybrid_linear_attn_backend.py is now optional, with an explicit ImportError raised only when SGLANG_USE_CUTEDSL_GDN_DECODE is enabled but the required cutlass dependency is missing.


Changelog
  • python/sglang/srt/layers/attention/aiter_backend.py
    • Removed the direct initialization of self.v_head_dim from token_to_kv_pool.get_value_buffer(0).shape[-1].
    • Implemented conditional logic to set self.v_head_dim based on the model type: using model_runner.model_config.v_head_dim for MLA models, model_runner.token_to_kv_pool.get_v_head_dim() for hybrid linear models, and falling back to the original method for standard models.
  • python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
    • Initialized cutedsl_fused_sigmoid_gating_delta_rule_update to None and introduced _cutedsl_import_error to track potential import failures.
    • Wrapped the import of cutedsl_fused_sigmoid_gating_delta_rule_update in a try-except block, making it an optional dependency.
    • Added an ImportError check within the __init__ method to explicitly notify users if SGLANG_USE_CUTEDSL_GDN_DECODE is enabled but the cutlass dependency is not found.
  • python/sglang/srt/models/qwen3_next.py
    • Reordered the conditions within the if statement in the _forward_input_proj method, moving seq_len < DUAL_STREAM_TOKEN_THRESHOLD to the end. This change, as described in the PR, effectively disables dual-stream functionality on AMD platforms.
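A minimal sketch of the condition reordering in the last changelog item, using hypothetical guard names (`is_cuda`, `dual_stream_enabled`) and a made-up threshold value. The point is that the platform guard is evaluated before the token-count check, so on AMD the condition short-circuits and dual-stream never activates.

```python
# Illustrative only: names and threshold are assumptions, not the real code.
DUAL_STREAM_TOKEN_THRESHOLD = 1024  # hypothetical value


def use_dual_stream(is_cuda: bool, dual_stream_enabled: bool, seq_len: int) -> bool:
    # Platform/config guards first; the seq_len comparison is moved to the
    # end, so non-CUDA (AMD) platforms short-circuit before reaching it.
    return is_cuda and dual_stream_enabled and seq_len < DUAL_STREAM_TOKEN_THRESHOLD
```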
Activity
  • The pull request was created by yichiche.
  • No human activity (comments, reviews, or progress updates) has been recorded yet.

@gemini-code-assist bot left a comment
Code Review

This pull request introduces support for Qwen3-Coder-Next on the AMD platform. The changes are well-structured and improve the codebase's robustness and modularity. Key modifications include a more reliable method for determining v_head_dim in aiter_backend.py to accommodate various model architectures, and making the CuTe DSL dependency optional in hybrid_linear_attn_backend.py with clear error handling. I have one suggestion to make the exception handling more specific, which will improve maintainability.

@yichiche yichiche marked this pull request as draft February 6, 2026 08:31
@yichiche yichiche force-pushed the enable-qwen3-coder-next-hip branch from 1cf0da3 to 6646948 on February 10, 2026 03:05
@yichiche yichiche marked this pull request as ready for review February 10, 2026 03:14
@gemini-code-assist bot commented

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@1am9trash (Collaborator) commented

I see the test uses --chunked-prefill-size 32768 (typically use 131072). Does Qwen3-coder-next tend to perform better with chunk prefill?

@yichiche (Collaborator, Author) commented

> I see the test uses --chunked-prefill-size 32768 (typically use 131072). Does Qwen3-coder-next tend to perform better with chunk prefill?

With the change from --chunked-prefill-size 32768 to 131072, mean TTFT improves from 887.24 ms to 838.34 ms (about a 6% improvement).

@HaiShaw HaiShaw self-assigned this Feb 12, 2026
- aiter_backend.py: Handle v_head_dim correctly for MLA and hybrid
  linear models. Previously, v_head_dim was retrieved directly from
  token_to_kv_pool.get_value_buffer(0), which fails for models where
  layer 0 may not be a full attention layer. Now properly handles
  MLA models (using model config), hybrid linear models (using
  get_v_head_dim()), and standard models.

- qwen3_next.py: Use is_cuda_alike() instead of is_cuda() to enable
  CUDA stream creation on both NVIDIA CUDA and AMD ROCm/HIP devices.
@yichiche yichiche force-pushed the enable-qwen3-coder-next-hip branch from 482a922 to 166541f on February 23, 2026 04:56
@yichiche yichiche requested a review from HaiShaw as a code owner February 23, 2026 04:56
@yichiche (Collaborator, Author) commented

Resolved conflicts and rebased.

@HaiShaw HaiShaw merged commit b2c46fc into sgl-project:main Feb 25, 2026
87 of 102 checks passed
@hubertlu-tw (Collaborator) commented

@yichiche do we currently have test coverage for this model or this model arch in our CI?

klhhhhh pushed a commit to klhhhhh/sglang that referenced this pull request Feb 26, 2026
Co-authored-by: yichiche@amd.com <jacky.cheng>
@yichiche (Collaborator, Author) commented Feb 26, 2026

> @yichiche do we currently have test coverage for this model or this model arch in our CI?

@hubertlu-tw yes, this is in another PR: #18608

magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Co-authored-by: yichiche@amd.com <jacky.cheng>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Co-authored-by: yichiche@amd.com <jacky.cheng>