
[MoE] Add LFM2 MoE tuning support + tuned configs for H100/B200/MI325X #22791

Merged
Fridge003 merged 6 commits into sgl-project:main from tugot17:lfm2-moe-tuned-configs on Apr 22, 2026

Conversation

@tugot17 (Contributor) commented Apr 14, 2026

Summary

Adds Lfm2MoeForCausalLM to the MoE tuning script and ships tuned fused MoE triton kernel configs for LFM2-8B-A1B and LFM2-24B-A2B at TP=1,2,4,8 on NVIDIA H100, B200, and AMD Instinct MI325X. Up to +47% throughput over default configs at high concurrency on NVIDIA.

Motivation

LFM2 MoE models (LiquidAI/LFM2-8B-A1B, LiquidAI/LFM2-24B-A2B) use num_experts / moe_intermediate_size config keys. The default Mixtral fallback in common_utils.py expects num_local_experts / intermediate_size, so tuning either crashes or produces wrong kernel shapes.

Without tuned configs, the fused MoE triton kernel falls back to generic defaults that are far from optimal for LFM2's expert shapes. Per-rank shard sizes:

| Model | Experts | TP=1 N | TP=2 N | TP=4 N | TP=8 N |
|---|---|---|---|---|---|
| LFM2-8B-A1B | E=32 | 1792 | 896 | 448 | 224 |
| LFM2-24B-A2B | E=64 | 1536 | 768 | 384 | 192 |
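
The shard sizes follow directly from moe_intermediate_size divided by the TP degree; a minimal sketch reproducing the table above (model shapes are the ones quoted in this PR):

```python
# Per-rank expert shard width N = moe_intermediate_size // tp.
# Expert counts and intermediate sizes are the LFM2 shapes quoted above.
MODELS = {
    "LFM2-8B-A1B":  {"num_experts": 32, "moe_intermediate_size": 1792},
    "LFM2-24B-A2B": {"num_experts": 64, "moe_intermediate_size": 1536},
}

for name, cfg in MODELS.items():
    shards = {tp: cfg["moe_intermediate_size"] // tp for tp in (1, 2, 4, 8)}
    print(f"{name} (E={cfg['num_experts']}): {shards}")
# LFM2-8B-A1B (E=32): {1: 1792, 2: 896, 4: 448, 8: 224}
# LFM2-24B-A2B (E=64): {1: 1536, 2: 768, 4: 384, 8: 192}
```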

Changes

  • benchmark/kernels/fused_moe_triton/common_utils.py — 4-line branch for Lfm2MoeForCausalLM that reads num_experts / moe_intermediate_size (a short sketch follows this list)
  • 16 new JSON configs in python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/ covering H100 and B200 × 8B/24B × TP=1,2,4,8
  • 8 new JSON configs in python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_6_0/ covering MI325X × 8B/24B × TP=1,2,4,8 (AMD environment ships triton 3.6.0)
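
For reference, a minimal sketch of how the new branch resolves kernel shapes from an LFM2 config. The field names match the diff hunk quoted in the review below; the trust_remote_code flag and the exact call pattern of the tuning script are assumptions for illustration:

```python
# Sketch only: read the MoE shape fields the new Lfm2MoeForCausalLM branch relies on.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("LiquidAI/LFM2-8B-A1B", trust_remote_code=True)
tp_size = 2

E = cfg.num_experts                                              # 32
topk = cfg.num_experts_per_tok
shard_intermediate_size = cfg.moe_intermediate_size // tp_size   # 1792 // 2 = 896

# The old Mixtral fallback instead read cfg.num_local_experts / cfg.intermediate_size,
# which are absent (or mean something else) on LFM2 configs, hence the crash or the
# wrong kernel shapes described in the Motivation section.
print(E, topk, shard_intermediate_size)
```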

Results

Peak output throughput at concurrency=8192, scenario D(1024,256), sustained. Server flags: `--enable-torch-compile --cuda-graph-max-bs 8192 --disable-radix-cache --mem-fraction-static 0.80 --dtype bfloat16`.

*(figure: nvidia_tuned_vs_baseline_tp1)*

H100 80GB HBM3

| Model | TP | Default (tok/s) | Tuned (tok/s) | Improvement |
|---|---|---|---|---|
| 8B | 1 | 13,290 | 19,610 | +47% |
| 8B | 2 | 20,190 | 25,680 | +27% |
| 8B | 4 | 25,430 | 28,790 | +13% |
| 8B | 8 | 29,160 | 29,600 | +2% |
| 24B | 1 | 7,572 | 10,450 | +38% |
| 24B | 2 | 12,800 | 15,510 | +21% |
| 24B | 4 | 16,920 | 19,990 | +18% |
| 24B | 8 | 20,130 | 21,020 | +4% |

B200

| Model | TP | Default (tok/s) | Tuned (tok/s) | Improvement |
|---|---|---|---|---|
| 8B | 1 | 24,840 | 36,430 | +47% |
| 8B | 2 | 30,170 | 35,330 | +17% |
| 8B | 4 | 34,570 | 46,120 | +33% |
| 8B | 8 | 37,320 | 48,390 | +30% |
| 24B | 1 | 16,110 | 23,070 | +43% |
| 24B | 2 | 23,120 | 24,090 | +4% |
| 24B | 4 | 23,800 | 25,410 | +7% |
| 24B | 8 | 32,710 | 35,000 | +7% |

MI325X

On AMD, sglang's default moe_runner_backend='auto' routes through aiter CK-MoE, which is faster than the triton fused-MoE and does not use these triton configs. The triton path is only active with --moe-runner-backend triton, which is required for LoRA serving on AMD (aiter CK-MoE does not support LoRA — see python/sglang/srt/layers/moe/moe_runner/runner.py).

With default triton MoE configs the triton path is ~40% slower than aiter. Our tuned configs close that gap almost entirely — at TP=1 the tuned triton path lands within 5–7% of aiter on both models:

| Model | Concurrency | Triton default (tok/s) | Triton tuned (tok/s) | aiter CK-MoE (tok/s) |
|---|---|---|---|---|
| 8B | 8192 | 9,620 | 13,770 | 14,480 |
| 24B | 8192 | 5,720 | 8,064 | 8,618 |
*(figure: mi325x_aiter_vs_triton_tp1)*

Takeaway: AMD users running default inference see no change (aiter still wins). LoRA-on-AMD users — who are forced onto the triton path — get a big speedup and near-aiter performance.

Verification

Independently verified on a separate H100 node (sglang-from-source install, triton 3.5.1):

  • Configs load correctly at startup for both 8B and 24B (the filename matching is sketched after this list):
    Using MoE kernel config from .../triton_3_5_1/E=32,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json
    Using MoE kernel config from .../triton_3_5_1/E=64,N=1536,device_name=NVIDIA_H100_80GB_HBM3.json
    
  • End-to-end chat completions produce sensible output for both models
  • Numerical correctness check (H100, LFM2-8B-A1B): ran with and without the tuned config — logprobs and ROUGE are essentially identical.
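
As referenced in the first bullet, the tuned config a server picks up is keyed on the expert count, the per-rank shard width, and the GPU name. An illustrative sketch of how such a filename is assembled (not the exact sglang loader code):

```python
# Illustrative only: reconstruct config filenames like the ones in the startup log above.
def config_filename(num_experts: int, shard_intermediate_size: int, device_name: str) -> str:
    return f"E={num_experts},N={shard_intermediate_size},device_name={device_name}.json"

# device_name is typically the CUDA device name with spaces replaced by underscores,
# e.g. torch.cuda.get_device_name() -> "NVIDIA H100 80GB HBM3" -> "NVIDIA_H100_80GB_HBM3".
print(config_filename(32, 1792, "NVIDIA_H100_80GB_HBM3"))
print(config_filename(64, 1536, "NVIDIA_H100_80GB_HBM3"))
# E=32,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json
# E=64,N=1536,device_name=NVIDIA_H100_80GB_HBM3.json
```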

Notes

Triton versions

Configs are split across two directories based on the triton version they were tuned against (a quick way to check the local version is sketched after the list):

  • triton_3_5_1/ — NVIDIA (H100, B200). This is what torch==2.9.1 installs from PyPI on bare-metal Linux x86_64, which matches upstream sglang's default pin. If a future torch bump pulls a different triton version, these configs would need to be retuned against that new version.
  • triton_3_6_0/ — AMD (MI325X). Triton 3.6.0 is what ships inside the lmsysorg/sglang:v0.5.10-rocm720-mi30x container. It also happens to be what lmsysorg/sglang:v0.5.10-cu124 ships on the NVIDIA container side, but we did not retune NVIDIA configs under 3.6.0 for this PR — followup work if containerized NVIDIA deployments need optimal kernels.
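
A quick way to confirm which torch/triton pair an environment ships, and therefore which of the two config directories will match:

```python
# Print the installed torch and triton versions; triton 3.5.1 matches triton_3_5_1/,
# triton 3.6.0 matches triton_3_6_0/.
import torch
import triton

print("torch: ", torch.__version__)
print("triton:", triton.__version__)
```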

Out of scope

  • Down-projection _down.json configs — sglang v0.5.10 supports separate configs for the w2 down-projection via tuning_fused_moe_triton_sep.py (which needs pre-generated topk_ids from a running server). Not tuned here; down-proj falls back to defaults and logs a benign down_moe=False warning at startup. Reported gains are measured with down-proj at defaults, so adding _down.json can only improve these numbers further.

Piotr Mazurek added 3 commits April 13, 2026 16:05
LFM2 MoE models (LiquidAI/LFM2-8B-A1B, LiquidAI/LFM2-24B-A2B) use
num_experts / moe_intermediate_size config keys. The default Mixtral
fallback expects num_local_experts / intermediate_size, so tuning
either crashes or produces wrong kernel shapes. This adds an explicit
branch for Lfm2MoeForCausalLM that reads the correct config fields.

Tuned fused MoE triton kernel configs for LFM2-8B-A1B (E=32) and
LFM2-24B-A2B (E=64) at TP=1,2,4,8 on NVIDIA H100 80GB HBM3 and B200.
Generated via tuning_fused_moe_triton.py grid search over 1920 kernel
configs per batch size (1-4096 + 8192).

Configs auto-load at inference via device_name matching. Delivers up
to +47% throughput over default kernel configs at high concurrency
(tp=1 benefits most; tp=8 shows minimal gains as shards approach
existing config sweet spots).

Tuned fused MoE triton kernel configs for LFM2-8B-A1B and LFM2-24B-A2B
at TP=1,2,4,8 on AMD Instinct MI325X. Generated via
tuning_fused_moe_triton.py inside the v0.5.10 ROCm container which
ships triton 3.6.0 (hence the separate directory from the
NVIDIA configs in triton_3_5_1/).

Note: On AMD, sglang routes MoE through aiter CK-MoE by default,
which does not use these triton configs. The configs take effect
only when --moe-runner-backend triton is set explicitly (e.g. for
LoRA workloads where aiter CK-MoE is unavailable).

@gemini-code-assist (Bot) left a comment


Code Review

This pull request adds support for the Lfm2MoeForCausalLM architecture in the MoE Triton kernel benchmark and introduces a comprehensive set of Triton kernel configurations for NVIDIA B200, H100, and AMD MI325X GPUs. The review feedback suggests refactoring the architecture configuration logic to reduce code duplication and ensuring that the new JSON files include a trailing newline for consistency.

Comment on lines +141 to +144

    elif architecture == "Lfm2MoeForCausalLM":
        E = config.num_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size

Severity: medium

This elif block is identical to the logic for other architectures like BailingMoEForCausalLM (lines 124-131) and Qwen2MoeForCausalLM (lines 73-82). To avoid code duplication and improve maintainability, consider adding Lfm2MoeForCausalLM to one of the existing lists of architectures that share this logic.
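
A hypothetical sketch of that suggestion (helper name and grouping are invented here; the field reads mirror the hunk quoted above, and the real layout of common_utils.py may differ):

```python
# Hypothetical refactor sketch: one shared branch for architectures that expose
# num_experts / num_experts_per_tok / moe_intermediate_size on their configs.
SHARED_MOE_KEY_ARCHS = {
    "Qwen2MoeForCausalLM",
    "BailingMoEForCausalLM",
    "Lfm2MoeForCausalLM",
}

def moe_shapes(architecture, config, ep_size=1):
    if architecture in SHARED_MOE_KEY_ARCHS:
        E = config.num_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        return E, topk, intermediate_size
    raise NotImplementedError(f"unhandled architecture: {architecture}")
```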

"num_stages": 2,
"waves_per_eu": 0
}
} No newline at end of file

Severity: medium

This file, and the other 7 new JSON configuration files for AMD Instinct MI325X, are missing a final newline character. It's a common convention to end text files with a newline. Please add one for consistency and to prevent potential issues with some tools.

@tugot17 (Contributor, Author) commented Apr 14, 2026

The question is: are the triton directories correct, and how can this stay future-proof when the triton version upgrades?

@Fridge003 (Collaborator) commented Apr 17, 2026

> The question is: are the triton directories correct, and how can this stay future-proof when the triton version upgrades?

@tugot17 Thanks for your support! When triton upgrades, the loader will hit the newest config version available.

Also please fix lint with `pre-commit install && pre-commit run --all-files`

@tugot17 (Contributor, Author) commented Apr 17, 2026

@Fridge003 linting now passes

@tugot17 (Contributor, Author) commented Apr 19, 2026

@Fridge003

could we merge it now?

@Fridge003 merged commit 6cf0b00 into sgl-project:main on Apr 22, 2026
56 of 64 checks passed
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
Wen-xuan-Xu added a commit to Wen-xuan-Xu/sglang that referenced this pull request Apr 29, 2026
After sgl-project#23019 moved the MoE config loader and the configs/ tree from
`fused_moe_triton/` to `moe_runner/triton_utils/`, two later PRs
unknowingly added 33 tuned-config JSONs to the OLD path:

- sgl-project#22791 (LFM2)        — 24 files (E=32/64, H100/B200/MI325X)
- sgl-project#23533 (Hy3 preview) —  9 files (E=192,N=192 incl. _down,
                                    H20/H20-3e/B200)

The runtime loader anchors its search via
os.path.dirname(os.path.realpath(__file__)) of the loader file
(now in moe_runner/triton_utils/), so configs in the old
directory were never read — runtime fell back to
get_default_config().

The configs themselves were properly tuned and benchmarked at
submission time via the in-process override_config() path used
by the tuning script — that is why the PR authors observed real
speedup. The bug is purely a wrong filesystem location.

Root cause: the tuning README still pointed contributors to the
old path. This PR moves the misplaced configs into the
runtime-loaded location and fixes the README.

Changes:
  * R100 git-mv 33 JSONs into moe_runner/triton_utils/configs/{triton_3_5_1,triton_3_6_0}/
  * Update benchmark/kernels/fused_moe_triton/README.md path

No content changes. No code changes.

References: sgl-project#23019 sgl-project#22791 sgl-project#23533
Wen-xuan-Xu added a commit to Wen-xuan-Xu/sglang that referenced this pull request Apr 29, 2026 (same commit message as above)