Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/optimizations.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,12 @@ Provides efficient Triton kernels to improve training speed and reduce memory us

### Expert Kernels

Optimized kernel implementations for Mixture of Experts (MoE) model training.
Optimized per-expert grouped-GEMM kernels for MoE training, with LoRA support.

- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
- **ScatterMoE**: Triton, any CUDA GPU.
- **SonicMoE**: CUTLASS / cute-DSL, Hopper+ only.

- **Config:** `use_scattermoe: true` or `use_sonicmoe: true`
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)

## Long Context Models
Expand Down Expand Up @@ -117,7 +118,7 @@ To train models that don't fit on a single GPU, you'll need to use a distributed

### N-D Parallelism (Beta)

For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence, Expert Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.

- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)

Expand Down
9 changes: 1 addition & 8 deletions src/axolotl/integrations/expert_parallel/experts_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _grouped_mm_local(experts, recv_x, recv_topk_idx, recv_topk_weights):


def _scattermoe_local(experts, recv_x, recv_topk_idx, recv_topk_weights):
from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import (
from axolotl.integrations.kernels.libs.scattermoe_lora.experts import (
scattermoe_experts_forward,
)

Expand All @@ -68,13 +68,6 @@ def _scattermoe_local(experts, recv_x, recv_topk_idx, recv_topk_weights):

def _sonicmoe_local(experts, recv_x, recv_topk_idx, recv_topk_weights):
raise NotImplementedError("Sonicmoe + EP is not yet properly implemented.")
# from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import (
# gemma4_sonicmoe_experts_forward,
# )

# return gemma4_sonicmoe_experts_forward(
# experts, recv_x, recv_topk_idx, recv_topk_weights
# )


_LOCAL_KERNELS = {
Expand Down
174 changes: 60 additions & 114 deletions src/axolotl/integrations/kernels/README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Kernels Integration

MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. Transformers v5 introduced a uniform dispatch point for the per-expert grouped GEMMs via the `experts_implementation` config kwarg:

```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
"sonicmoe": sonicmoe_experts_forward, # upstream HF integration
}
```

In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.
Axolotl registers two additional implementations into this same global registry: **ScatterMoE** (Triton, runs on any CUDA GPU) and a LoRA-aware **SonicMoE** variant (CUTLASS / cute-DSL, Hopper or newer). Routing — softmax/sigmoid top-k, group selection, shared experts, bias correction, etc. — stays in each model's `SparseMoEBlock`, where transformers handles all per-architecture variation. Axolotl only swaps the experts forward.

## Usage

Expand All @@ -28,130 +29,75 @@ use_scattermoe: true
use_sonicmoe: true
```

**Important:** Setting `experts_implementation` to `batched_mm` or `grouped_mm` is incompatible with custom kernel options. The exception is `experts_implementation: scattermoe`, which is used for models like Gemma 4 that embed MoE directly in the decoder layer (no SparseMoeBlock) and dispatch through the transformers `ExpertsInterface`.
`experts_implementation` is auto-set to `scattermoe` / `sonicmoe` from the kernel flag, but you can override to `eager` / `batched_mm` / `grouped_mm` to compare against the transformers reference implementations.

### SonicMoE installation

**Prerequisites:**
- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
- NVIDIA Hopper (H100/H200) or Blackwell (B200/GB200/B300) GPU
- CUDA 12.9+ (13.0+ for B300)
- PyTorch 2.7+ (2.9.1 recommended)
- For B300: Triton 3.6.0
- PyTorch 2.7+
- For B300: Triton 3.6.x

The sonic-moe kernel ships through the HF [`kernels`](https://github.com/huggingface/kernels) package. Transformers v5.8+ auto-fetches a prebuilt kernel from [`kernels-community/sonic-moe`](https://huggingface.co/kernels-community/sonic-moe) on first use:

```bash
pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
pip install kernels "nvidia-cutlass-dsl==4.4.2"
```

See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.

**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
**Note:** Blackwell support is in upstream beta. On Blackwell GPUs Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.

## How It Works

The `KernelsPlugin` runs before model loading and:

### ScatterMoE
1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library.

### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports pluggable routing strategies (see routing table below).

Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.

## Model Support Matrix

Most models use the **SwiGLU** activation (`silu(gate) * up`). Gemma 4 uses **GEGLU** (`gelu(gate) * up`). ScatterMoE supports any gated activation (activation is applied in Python between kernel calls). SonicMoE supports SwiGLU, GEGLU, and REGLU via its `ActivationType` enum.

### Routing strategies

| Routing Strategy | Description | ScatterMoE | SonicMoE |
|---|---|:---:|:---:|
| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes |
| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes |
| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes |
| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes |
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| softmax → topk + per_expert_scale | RMSNorm → scale → proj → softmax → topk → renorm → per-expert learned scales | Yes | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |

### Per-model support

| Model Type | Architecture | Routing | ScatterMoE | SonicMoE |
|---|---|---|:---:|:---:|
| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** |
| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** |
| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** |
| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** |
| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** |
| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** |
| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** |
| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** |
| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** |
| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** |
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gemma4_text` | Gemma 4 (26B-A4B) | softmax → topk + per_expert_scale | **Yes**\*\* | **Yes**\*\* |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |

\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.

\*\* Gemma 4 uses `experts_implementation: scattermoe` path (registered via `ExpertsInterface`) instead of SparseMoeBlock patching, since Gemma 4 embeds MoE directly in its decoder layer (no separate SparseMoeBlock). See the [Gemma 4 section](#gemma-4) below.

### Feature comparison

| Feature | ScatterMoE | SonicMoE |
|---|:---:|:---:|
| Kernel backend | Triton | CUTLASS |
| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) |
| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd |
| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) |
| Gate/router LoRA | Yes | Yes |
| Expert LoRA | Yes (fused) | Yes (materialized) |
| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` |
| torch.compile routing | No | Yes (optional) |

## Shared Expert Handling

Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:

1. `shared_expert` (Qwen2-MoE)
2. `shared_experts` (GLM-MoE, DeepSeek-V3)
3. `shared_mlp` (HunYuan V1 MoE)

If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.

## Gemma 4

Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:

- **No SparseMoeBlock**: MoE is embedded directly in the decoder layer alongside a dense MLP. Both run in parallel and their outputs are summed.
- **Custom router** (`Gemma4TextRouter`): RMSNorm → learned scale → linear projection → softmax → top-k → renormalization → per-expert learned scales.
- **GEGLU activation**: Uses `gelu_pytorch_tanh` (not SiLU/SwiGLU like most other MoE models).
- **128 experts, top-k=8** for the 26B-A4B variant.

Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.

## Limitations

- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).
- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed.
- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP.
The `KernelsPlugin` runs once before model loading and:

1. Calls `register_scattermoe_experts()` or `register_sonicmoe_experts()`, which inserts the kernel forward into `transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS`.
2. Sets `cfg.experts_implementation` to the matching name.
3. When the model loads, transformers' `@use_experts_implementation` decorator on each model's `Experts` class reads `config._experts_implementation` and dispatches to our registered forward.

That's the entire integration — there is no per-architecture SparseMoEBlock monkey-patch, no per-model routing code, and no weight-layout conversion. As new MoE models adopt the decorator upstream they immediately benefit from both kernels.

## LoRA Support

Both kernels train PEFT adapters on `gate_up_proj` / `down_proj` (and `gate` for the router) end-to-end:

- **ScatterMoE** fuses the LoRA `B @ A` product into the per-expert grouped GEMM via custom Triton kernels (`parallel_linear_lora`). No extra materialization pass.
- **SonicMoE** materializes `W_eff = W + scaling * (B @ A)` per expert inside a custom `MoELoRAMaterialize` `autograd.Function` and passes the effective weight into the CUTLASS kernel. Backward decomposes `dW_eff` into `dA` and `dB` via the chain rule, so LoRA parameters train without modifying the kernel.

Both paths detect PEFT `ParamWrapper` on individual expert parameters (`target_parameters` API) and unwrap them before dispatch.

## Model Support

Any model whose `Experts` class is decorated with `@use_experts_implementation` upstream works automatically. As of transformers 5.8 this includes (verified):

| Model Type | ScatterMoE | SonicMoE |
|-------------------|:---------:|:--------:|
| `mixtral` | Yes | Yes |
| `qwen2_moe` | Yes | Yes |
| `qwen3_moe` | Yes | Yes |
| `qwen3_5_moe` | Yes | Yes |
| `olmoe` | Yes | Yes |
| `mistral4` | Yes | Yes |
| `glm_moe_dsa` | Yes | Yes |
| `deepseek_v3` | Yes | Yes |
| `minimax_m2` | Yes | Yes |
| `ernie4_5_moe` | Yes | Yes |
| `hunyuan_v1_moe` | Yes | Yes |
| `gemma4_text` | Yes | Yes |
| `gpt_oss` | No | Yes |

`gpt_oss` carries the decorator with `is_concatenated=False, is_transposed=True, has_bias=True` and uses a sigmoid-GLU activation with clamping. The SonicMoE forward reads these flags off `self` and dispatches accordingly. The ScatterMoE forward assumes the standard `[E, 2*I, H]` concat layout and SiLU-GLU without bias, so it does not yet support `gpt_oss`.

## Feature comparison

| Feature | ScatterMoE | SonicMoE |
|----------------------------------|:----------:|:--------:|
| Kernel backend | Triton | CUTLASS / cute-DSL |
| GPU requirement | Any CUDA | Hopper+ |
| LoRA path | Fused in Triton kernel | `MoELoRAMaterialize` + custom autograd |
| LoRA overhead | Lower (fused) | Higher (materialization pass) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Standard `[E, 2*I, H]` | Standard `[E, 2*I, H]` (concat layout, no interleave) |

## Note on MegaBlocks

Expand Down
63 changes: 56 additions & 7 deletions src/axolotl/integrations/kernels/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
LOG = get_logger(__name__)


# Valid experts_implementation values:
# - "eager" : transformers' per-token loop reference implementation
# - "batched_mm" : transformers' built-in batched matmul path
# - "grouped_mm" : transformers' built-in grouped matmul path (cache-efficient)
# - "scattermoe" : axolotl-registered Triton kernels with LoRA support
# - "sonicmoe" : axolotl-registered CUTLASS / cute-DSL kernels with LoRA support
# - "deep_ep[_*]": EP-plugin composites; passed through when expert_parallel_size > 1
_BUILTIN_EXPERTS_IMPLS = {"eager", "batched_mm", "grouped_mm"}
_KERNEL_EXPERTS_IMPLS = {"scattermoe", "sonicmoe"}
_EP_EXPERTS_IMPLS = {
"deep_ep",
"deep_ep_grouped_mm",
"deep_ep_scattermoe",
"deep_ep_sonicmoe",
}
_VALID_EXPERTS_IMPLS = (
_BUILTIN_EXPERTS_IMPLS | _KERNEL_EXPERTS_IMPLS | _EP_EXPERTS_IMPLS
)


class KernelsArgs(BaseModel):
use_scattermoe: bool | None = None
use_sonicmoe: bool | None = None
Expand All @@ -30,24 +50,53 @@ def check_use_kernels(cls, data):

return data

@model_validator(mode="before")
@classmethod
def check_sonicmoe_ep_unsupported(cls, data):
"""SonicMoE + EP is not yet implemented (EP `_sonicmoe_local` raises)."""
if data.get("use_sonicmoe") and (data.get("expert_parallel_size") or 1) > 1:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n 'bool\(data\.get\("use_scattermoe"\)\)|bool\(data\.get\("use_sonicmoe"\)\)|\(data\.get\("expert_parallel_size"\) or 1\) > 1' src/axolotl/integrations/kernels/args.py

python - <<'PY'
print("bool('false') =", bool("false"))
print("bool('0') =", bool("0"))
for v in ["2", 2, None]:
    try:
        print(f"{v!r}: {(v or 1) > 1}")
    except Exception as e:
        print(f"{v!r}: {type(e).__name__} -> {e}")
PY

Repository: axolotl-ai-cloud/axolotl

Length of output: 405


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="src/axolotl/integrations/kernels/args.py"

# Show the relevant surrounding code and any long comment blocks near the referenced line ranges
python - <<'PY'
import itertools, pathlib
path = pathlib.Path("src/axolotl/integrations/kernels/args.py")
lines = path.read_text().splitlines()
for start, end in [(1, 35), (45, 90), (70, 95), (75, 110)]:
    print(f"\n--- {start}-{end} ---")
    for i in range(start, min(end, len(lines)) + 1):
        print(f"{i:4d}: {lines[i-1]}")
PY

Repository: axolotl-ai-cloud/axolotl

Length of output: 6985


🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n 'use_(sonicmoe|scattermoe)\s*:\s*' . -g'*.yml' -g'*.yaml' -g'*.json' || true
rg -n 'expert_parallel_size\s*:\s*' . -g'*.yml' -g'*.yaml' || true
rg -n 'model_validator\(mode="before"\)|mode="before"' src/axolotl/integrations/kernels/args.py src/axolotl -S || true
rg -n 'str\(.+\)\.lower\(\).*(true|false)|parse.*bool|coerce.*bool|in \("true"' src/axolotl || true

Repository: axolotl-ai-cloud/axolotl

Length of output: 10349


Avoid pre-coercion truthiness/type issues in before-validators

  • src/axolotl/integrations/kernels/args.py line 57 and lines 69-70 can mis-handle raw string inputs: bool("false") evaluates to True, and ("2" or 1) > 1 raises TypeError (str vs int) before coercion.
  • Multi-line comment blocks at lines 8-14 and 78-80 violate the “comments should be max one short line” guideline (they document option values / behavior rather than “WHY”).
🐛 Proposed fix
@@
     def check_sonicmoe_ep_unsupported(cls, data):
         """SonicMoE + EP is not yet implemented (EP `_sonicmoe_local` raises)."""
-        if data.get("use_sonicmoe") and (data.get("expert_parallel_size") or 1) > 1:
+        ep_size_raw = data.get("expert_parallel_size")
+        try:
+            ep_size = int(ep_size_raw) if ep_size_raw is not None else 1
+        except (TypeError, ValueError):
+            ep_size = 1
+        if data.get("use_sonicmoe") is True and ep_size > 1:
             raise ValueError(
                 "use_sonicmoe=true is not supported with expert_parallel_size > 1. "
                 "Use use_scattermoe=true under EP, or set expert_parallel_size=1."
             )
         return data
@@
-        use_scattermoe = bool(data.get("use_scattermoe"))
-        use_sonicmoe = bool(data.get("use_sonicmoe"))
+        use_scattermoe = data.get("use_scattermoe") is True
+        use_sonicmoe = data.get("use_sonicmoe") is True
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/kernels/args.py` at line 57, The before-validator
currently evaluates raw dict values directly (e.g., data.get("use_sonicmoe") and
(data.get("expert_parallel_size") or 1) > 1) which mis-handles string inputs
like "false" or "2" and can raise TypeError; update the before-validator to
explicitly coerce/parse these fields first (e.g., parse use_sonicmoe into a
boolean using a safe string-to-bool helper and coerce expert_parallel_size to
int with a fallback) and then perform the comparison, adjusting any helper or
validator functions used by the validator; also replace the multi-line comment
blocks that document option values with short single-line comments (split or
rephrase the content around the validator and the option docs) to comply with
the one-line comment guideline.

raise ValueError(
"use_sonicmoe=true is not supported with expert_parallel_size > 1. "
"Use use_scattermoe=true under EP, or set expert_parallel_size=1."
)
return data

@model_validator(mode="before")
@classmethod
def check_experts_implementation(cls, data):
"""Auto-select impl from kernel flags; reject mismatched/unknown values."""
experts_implementation = data.get("experts_implementation")
use_scattermoe = data.get("use_scattermoe", False)
use_scattermoe = bool(data.get("use_scattermoe"))
use_sonicmoe = bool(data.get("use_sonicmoe"))

if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation == "scattermoe" and not use_scattermoe:
if use_scattermoe:
data["experts_implementation"] = "scattermoe"
elif use_sonicmoe:
data["experts_implementation"] = "sonicmoe"
else:
# Transformers defaults to a non-eager backend when unset; pin to
# eager unless the user explicitly opts in.
data["experts_implementation"] = "eager"
return data

if experts_implementation == "scattermoe" and not use_scattermoe:
LOG.warning(
"`experts_implementation='scattermoe'` requires `use_scattermoe: true`. "
"Automatically setting to 'eager'."
)
data["experts_implementation"] = "eager"
elif experts_implementation not in ("eager", "scattermoe"):
elif experts_implementation == "sonicmoe" and not use_sonicmoe:
LOG.warning(
"`experts_implementation='sonicmoe'` requires `use_sonicmoe: true`. "
"Automatically setting to 'eager'."
)
data["experts_implementation"] = "eager"
elif experts_implementation not in _VALID_EXPERTS_IMPLS:
LOG.warning(
f"`experts_implementation={experts_implementation!r}` is not compatible with "
f"custom MoE kernels. Automatically setting to 'eager'."
f"`experts_implementation={experts_implementation!r}` is not recognized. "
f"Valid options: {sorted(_VALID_EXPERTS_IMPLS)}. "
f"Automatically setting to 'eager'."
)
data["experts_implementation"] = "eager"

Expand Down
Loading
Loading