Skip to content

[MoE] Improve moe kernels for unsloth fine tuning#3812

Merged
danielhanchen merged 28 commits into
unslothai:mainfrom
Datta0:moe_kernels
Feb 5, 2026
Merged

[MoE] Improve moe kernels for unsloth fine tuning#3812
danielhanchen merged 28 commits into
unslothai:mainfrom
Datta0:moe_kernels

Conversation

@Datta0

@Datta0 Datta0 commented Jan 1, 2026

Copy link
Copy Markdown
Collaborator

This might come in handy for unslothai/unsloth-zoo#396
We do prefer using grouped_mm there but this is a fallback for that before going pure pytorch mode
Needs transformers v5 there

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @Datta0, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on significantly improving the performance and usability of Mixture-of-Experts (MoE) kernels, particularly for Qwen3 models, within the Unsloth framework. The core changes involve integrating Triton-optimized grouped GEMM operations, implementing a robust caching system for kernel configurations to reduce overhead, and ensuring better compatibility with "torch.compile" by handling tracing and dynamic kernel parameters. These enhancements aim to provide a more efficient and streamlined fine-tuning experience for MoE architectures.

Highlights

  • MoE Kernel Autotuning Cache: Introduced a new caching system to store and reuse auto-tuned configurations for Mixture-of-Experts (MoE) kernels, preventing redundant tuning at runtime.
  • Triton Kernel Compatibility with torch.compile: Enhanced Triton grouped GEMM kernels with the "@allow_in_graph" decorator and a tracing detection mechanism ("_is_tracing") to ensure seamless integration and prevent issues when using "torch.compile".
  • Dynamic Kernel Parameter Handling: Modified Triton kernels to accept "NUM_TOKENS" and "NUM_SMS" as dynamic parameters rather than compile-time constants, and removed "NUM_TOKENS" from autotuning keys to reduce recompilation.
  • Qwen3 MoE Triton Integration: Added a new module ("qwen3_moe_triton.py") to specifically integrate and leverage Triton grouped GEMM kernels for Qwen3 Mixture-of-Experts models, including a fallback mechanism and pre-autotuning during model loading.
  • TMA Support Refinements: Updated Triton Tensor Memory Accelerator (TMA) descriptor creation to use "tl.make_tensor_descriptor" and temporarily disabled TMA due to ongoing compatibility work, while adding auto-detection for TMA parameters.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

N = W.shape[0] // num_experts
assert K == W.shape[1], f"K ({K}) must match W.shape[1] ({W.shape[1]})"

if fuse_mul_post:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore fused mul warning flag initialization

The fused-mul path now references _FUSED_MUL_WARN but the module no longer defines the flag after the TMA refactor, so any call with fuse_mul_post=True hits a NameError before executing the kernel. This is a regression from the previous version where the guard was initialized and will crash inference paths that rely on post-mul fusion.

Useful? React with 👍 / 👎.

Comment thread unsloth/models/qwen3_moe_triton.py Outdated
Comment on lines +179 to +183
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes = num_experts
).permute(2, 1, 0)

token_counts_by_expert = expert_mask.sum(dim = 1).int()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Use per-expert token counts for grouped GEMM

In the Triton forward for Qwen3 MoE, token_counts_by_expert is computed as expert_mask.sum(dim=1) which yields a [num_experts, num_tokens] matrix instead of the expected 1‑D counts per expert. The grouped GEMM kernels read m_sizes as a length‑num_experts vector, so this flattened matrix feeds arbitrary per‑token entries into the kernel, producing incorrect routing and outputs whenever the Triton MoE path is enabled.

Useful? React with 👍 / 👎.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a valuable auto-tuning cache system for MoE Triton kernels, which will prevent re-tuning on each run and significantly improve startup times. The changes also enhance torch.compile compatibility and add a new Triton-optimized path for Qwen3-MoE models. While the overall direction is excellent, I've found a critical bug in the token routing logic within the new qwen3_moe_triton.py file that needs to be addressed. Additionally, there are a couple of smaller issues regarding TMA support detection and error message handling that I've flagged for improvement.

Comment thread unsloth/models/qwen3_moe_triton.py Outdated
Comment on lines +43 to +48
def _check_tma_support():
import triton.language as tl

gpu_supports_tma = torch.cuda.get_device_capability()[0] >= 9
# Check for both old experimental and new stable API names
triton_has_tma_api = hasattr(tl, "make_tensor_descriptor") or hasattr(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _check_tma_support can lead to runtime errors. It checks for the existence of either make_tensor_descriptor or _experimental_make_tensor_descriptor, but the kernels now exclusively use make_tensor_descriptor. If a user has a Triton version with only the experimental API, _SUPPORTS_TMA will be True, but the kernel launch will fail with an AttributeError. The check should be made stricter to only hasattr(tl, 'make_tensor_descriptor') to match the kernel's expectation, or the kernels should be updated to handle both API versions.

Suggested change
def _check_tma_support():
import triton.language as tl
gpu_supports_tma = torch.cuda.get_device_capability()[0] >= 9
# Check for both old experimental and new stable API names
triton_has_tma_api = hasattr(tl, "make_tensor_descriptor") or hasattr(
def _check_tma_support():
import triton.language as tl
gpu_supports_tma = torch.cuda.get_device_capability()[0] >= 9
# Kernels now use the stable `make_tensor_descriptor` API.
triton_has_tma_api = hasattr(tl, 'make_tensor_descriptor')
return gpu_supports_tma and triton_has_tma_api

Comment thread unsloth/kernels/moe/autotune_cache.py
@danielhanchen danielhanchen changed the base branch from nightly to main January 30, 2026 12:05
@danielhanchen danielhanchen merged commit c582a63 into unslothai:main Feb 5, 2026
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
* Improve MoE performance

* small changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix imports

* disable autotune

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* LoRA for MoE

* Make autotune default

* make dy contiguous

* use non lora model as base for RL

* Revert "use non lora model as base for RL"

This reverts commit bc8f156.

* fixup derp

* non TMA [T4]

* Revert "non TMA [T4]"

This reverts commit 3530456.

* Fixes for VL MoE and v5 transformers

* [transformers] [v5] remove unused hybridcache (unslothai#3910)

* remote unused hybridcache

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* No double compile for qwen3moe

* Fix top_k on trl GRPO

* Recognise GLM as MoE

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix missing RotaryEmbeddingConfigMixin

* Licensing for autotuning cache

* Cleanup

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Erland366 <erland.pg366@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants