
[NPU] Support GPT-OSS for NPU #14197

Merged

iforgetmyname merged 18 commits into sgl-project:main from Todobe:gpt-oss on Jan 18, 2026

Conversation

@Todobe (Contributor) commented on Dec 1, 2025

Co-author: @McZyWu

Motivation

Adapt the GPT-OSS model to run on NPU (Ascend) hardware.

Modifications

  1. Added operators to the Ascend attention backend that handle attention sinks and sliding windows (see the sketch after this list).
  2. Replaced the torch_npu swiglu function, which caused precision issues, with a custom implementation.
  3. Adapted other NPU-incompatible method calls.
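
Since the fused Ascend operator itself is not shown in this thread, the following is a naive PyTorch sketch of what attention with sinks and a sliding window computes. The [heads, seq, head_dim] shapes and the sink-as-extra-softmax-logit convention follow the public GPT-OSS design and are assumptions here, not the PR's kernel code.

import torch

def attention_with_sink_and_window(q, k, v, sinks, window):
    # Naive reference for causal attention with a sliding window and per-head
    # learned sink logits (a sketch, not the fused Ascend operator in this PR).
    # q, k, v: [heads, seq, head_dim]; sinks: [heads]; window: int
    h, s, d = q.shape
    scores = torch.einsum("hqd,hkd->hqk", q, k) / d**0.5
    idx = torch.arange(s, device=q.device)
    causal = idx[None, :] <= idx[:, None]                # key position <= query position
    in_window = (idx[:, None] - idx[None, :]) < window   # query attends to the last `window` keys
    scores = scores.masked_fill(~(causal & in_window), float("-inf"))
    # The sink acts as an extra "virtual" key: it joins the softmax, absorbs
    # probability mass, and is then dropped before mixing the values.
    sink_col = sinks.view(h, 1, 1).expand(h, s, 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]
    return torch.einsum("hqk,hkd->hqd", probs, v)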

Accuracy Tests

We tested the gpt-oss-120b (bf16) model on the GSM8K dataset using 256 questions and max_new_tokens=512. The accuracy was 84.4%.

(Screenshot: GSM8K accuracy results)
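
For a quick single-question spot check against a running server, a minimal sketch using SGLang's /generate HTTP endpoint is shown below; the host, port, prompt format, and sampling settings are illustrative assumptions, not the harness used for the 84.4% result above.

import requests

# Assumes an SGLang server serving gpt-oss-120b is already running locally;
# the URL, port, and prompt template are placeholders, not values from this PR.
URL = "http://127.0.0.1:30000/generate"

question = (
    "Natalia sold clips to 48 of her friends in April, and then she sold half "
    "as many clips in May. How many clips did Natalia sell altogether in April and May?"
)

payload = {
    "text": f"Question: {question}\nAnswer:",
    "sampling_params": {"max_new_tokens": 512, "temperature": 0.0},
}
print(requests.post(URL, json=payload, timeout=120).json()["text"])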

Benchmarking and Profiling

Checklist

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @Todobe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the SGLang framework by enabling the GPT-OSS model to operate effectively on NPU hardware. The changes involve fundamental adjustments to the attention mechanism and activation functions within the Ascend backend, ensuring both compatibility and improved performance. The overall goal is to broaden the hardware support for the GPT-OSS model, allowing it to leverage the specialized capabilities of NPUs.

Highlights

  • NPU Adaptation: The GPT-OSS model has been adapted to run on NPU (Neural Processing Unit) hardware, specifically targeting the Ascend backend.
  • Attention Mechanism Enhancements: New operators capable of handling 'sinks' and 'sliding windows' have been integrated into the Ascend backend for the attention mechanism, improving its functionality on NPU.
  • Precision Fix for SwiGLU: The default swiglu function in torch_npu was replaced with a custom implementation to resolve precision issues affecting the GPT-OSS model on NPU.
  • NPU Compatibility: Various other incompatible method calls across the codebase have been identified and adapted to ensure proper functioning on NPU.
  • Accuracy Improvement: Accuracy tests on the GPT-oss-120b-bf16 model on the GSM8K dataset showed 67.2% accuracy on NPU, which is an improvement over the 62.5% observed on GPU for the same test.

@Todobe changed the title from "Gpt oss" to "Adapt GPT-OSS for NPU" on Dec 1, 2025
@Todobe changed the title from "Adapt GPT-OSS for NPU" to "Support GPT-OSS for NPU" on Dec 1, 2025
@gemini-code-assist (bot) left a comment:


Code Review

This pull request adapts the GPT-OSS model for NPU by adding operators for sinks and sliding windows, replacing a problematic swiglu function, and other NPU-specific adaptations. The changes are generally well-structured. I've identified a critical bug in the activation function logic for NPU MoE layers, which could affect models other than GPT-OSS. I've also suggested a refactoring to reduce code duplication in the attention backend. Addressing these points will improve the correctness and maintainability of the code.

@ping1jing2 self-assigned this on Dec 2, 2025
@ping1jing2 changed the title from "Support GPT-OSS for NPU" to "[Ascend]Support GPT-OSS for NPU" on Dec 2, 2025
Collaborator review comment:
revert this file

Collaborator review comment:
revert this file

Comment on lines +514 to +516
if self.moe_runner_config.custom_act_fn is not None:
    hidden_states = self.moe_runner_config.custom_act_fn(layer, hidden_states)
elif self.moe_runner_config.activation == "silu":
Collaborator review comment:
add an option swiglu_oai for activation here
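
A hedged sketch of what the suggested named-activation option might look like, written as a standalone helper; the gate/up chunk layout, the clamp limit of 7.0, and the alpha of 1.702 are assumptions based on the public GPT-OSS activation, not the merged SGLang code.

import torch

def apply_moe_activation(hidden_states, activation, alpha=1.702, limit=7.0):
    # Dispatch on an activation name instead of a custom callable, as suggested.
    # Assumes gate/up projections are concatenated along the last dimension.
    gate, up = hidden_states.chunk(2, dim=-1)
    if activation == "silu":
        return torch.nn.functional.silu(gate) * up
    if activation == "swiglu_oai":
        gate = gate.clamp(max=limit)            # clamp the gate from above only
        up = up.clamp(min=-limit, max=limit)    # clamp the linear branch symmetrically
        return gate * torch.sigmoid(alpha * gate) * (up + 1)
    raise NotImplementedError(f"Unsupported activation: {activation}")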

Comment on lines +134 to +140
custom_act_fn = None
if _is_npu:
    if self.layer_id == 0:
        logger.warning(
            "Warning: GPT-OSS use custom activate function on FusedMoE when using ascend backend."
        )
    custom_act_fn = swiglu_oai
Collaborator review comment:
revert this part of changes, instead, modify modelconfig when init model
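
A small sketch of the alternative the reviewer proposes: choose the activation once when the model config is initialized, instead of branching on _is_npu inside FusedMoE. The field name and the "swiglu_oai" string are assumptions for illustration.

from dataclasses import dataclass

@dataclass
class MoeRunnerConfig:
    # Hypothetical config field; the real SGLang field names may differ.
    activation: str = "silu"

def make_moe_runner_config(is_npu_backend: bool) -> MoeRunnerConfig:
    # Decide the activation at model-init time rather than inside the MoE layer.
    return MoeRunnerConfig(activation="swiglu_oai" if is_npu_backend else "silu")

print(make_moe_runner_config(is_npu_backend=True))  # MoeRunnerConfig(activation='swiglu_oai')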

gemm1_clamp_limit=self.gemm1_clamp_limit,
with_bias=True,
prefix=add_prefix("experts", prefix),
custom_act_fn=custom_act_fn,
Collaborator review comment:
ditto

),
)
extra_args = {}
if _is_cuda:
Collaborator review comment (suggested change):

- if _is_cuda:
+ if not _is_npu:

@iforgetmyname (Collaborator) commented:

/tag-and-rerun-ci

@iforgetmyname changed the title from "[Ascend]Support GPT-OSS for NPU" to "[NPU]Support GPT-OSS for NPU" on Jan 18, 2026
@iforgetmyname merged commit 733de6b into sgl-project:main on Jan 18, 2026
375 of 401 checks passed
DotSlash-A pushed a commit to DotSlash-A/sglang that referenced this pull request on Jan 19, 2026

self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()

if is_npu:
Contributor review comment:
This should be _is_npu, otherwise this condition will always be True and will hinder running on other backends.
@Todobe can you fix it?
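
A minimal sketch of the fix being requested: bind the result of the platform check to a module-level boolean and test that, since the bare function object is always truthy. The sglang.srt.utils import path and the surrounding class are assumptions for illustration.

from sglang.srt.utils import is_npu

_is_npu = is_npu()  # evaluate the platform check once at import time

class GptOssModelSketch:  # illustrative container, not the real model class
    def __init__(self, config):
        self.vocab_size = config.vocab_size
        if _is_npu:  # correct: tests the cached boolean, not the bare `is_npu` function
            ...      # NPU-specific setup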


Labels: npu, quant (LLM Quantization), run-ci

4 participants