QMoE CPU Performance Update (Up to 4x on 4-bit) #27364

Merged
tianleiwu merged 12 commits into main from tlwu/20260216/qmoe_cpu_q4_perf
Feb 21, 2026
@tianleiwu
Contributor

Summary

This change improves QMoE CPU performance by moving more work to prepack time and enabling the DirectQ4 GEMM fast path where appropriate, while preserving an env-var switch for performance/accuracy A/B testing.

This PR introduces:

  • Prepack and cache infrastructure for QMoE expert weights.
  • DirectQ4 packed-B cache built during prepack (instead of mutable runtime cache in Compute()).
  • Fast-path support for block-wise cases (including block size 32 where supported by MLAS Q4 type).
  • Runtime toggle via ORT_USE_MLAS_Q4_GEMM_MOE.
  • Default fast-path policy refined to avoid known accuracy-loss scenarios unless explicitly overridden by env var.
  • Test and benchmark refinements for QMoE CPU validation.

Key Implementation Changes

1. Prepack-time cache build

  • Moves DirectQ4 packed-B cache construction to prepack stage.
  • Removes mutable runtime cache maintenance from Compute().
  • Reduces per-inference overhead and avoids mutable shared cache complexity.
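
A minimal sketch of the prepack-once idea, written in Python for brevity (the operator itself is C++). The names `PrepackedExperts`, `prepack`, `compute`, and the `pack` callback are hypothetical stand-ins, not ORT or MLAS APIs. The sketch only shows the shape of the change: the cache is filled during initialization and read without mutation afterwards.

```python
from typing import Callable, Dict, List

# Hypothetical packing callback: raw expert weights in, packed bytes out.
PackFn = Callable[[List[int]], bytes]


class PrepackedExperts:
    """Illustrative holder that packs every expert's weights exactly once."""

    def __init__(self, pack: PackFn) -> None:
        self._pack = pack
        self._packed: Dict[int, bytes] = {}
        self._ready = False

    def prepack(self, expert_weights: Dict[int, List[int]]) -> None:
        # Initialization time: build the whole packed cache up front.
        for expert_id, weights in expert_weights.items():
            self._packed[expert_id] = self._pack(weights)
        self._ready = True

    def compute(self, expert_id: int) -> bytes:
        # Inference time: read-only lookup, so no cache mutation is needed.
        if not self._ready:
            raise RuntimeError("prepack() must run before compute()")
        return self._packed[expert_id]
```

Because `compute` never writes to the cache, repeated inferences pay only for a lookup, which is the overhead reduction described above.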

2. Fast path vs fallback

  • Keeps two execution modes:
    • DirectQ4 GEMM fast path (MlasQ4GemmPackB + DirectQ4Gemm cache usage).
    • Fallback path (DequantizePrePacked + MlasGemm).
  • Allows controlled fallback for accuracy-sensitive configurations.
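
As a rough illustration of what a dequantize-then-GEMM fallback does conceptually, the sketch below unpacks 4-bit values, applies per-block scales, and then runs a plain matrix-vector product. The nibble order (low nibble first), the zero point of 8, and the scale layout are assumptions made for this example and are not taken from the QMoE or MLAS sources. `dequantize_q4` and `matvec` are hypothetical helpers, with `matvec` standing in for the GEMM call.

```python
from typing import List


def dequantize_q4(packed: bytes, scales: List[float], block_size: int,
                  zero_point: int = 8) -> List[float]:
    """Unpack two 4-bit values per byte and apply one scale per block.

    Assumed layout (illustrative only): low nibble first, flat block index.
    """
    quantized: List[int] = []
    for byte in packed:
        quantized.append(byte & 0x0F)  # low nibble
        quantized.append(byte >> 4)    # high nibble
    return [(q - zero_point) * scales[i // block_size]
            for i, q in enumerate(quantized)]


def matvec(weights: List[float], x: List[float], rows: int, cols: int) -> List[float]:
    """Row-major matrix-vector product, standing in for the GEMM call."""
    return [sum(weights[r * cols + c] * x[c] for c in range(cols))
            for r in range(rows)]
```

A fast path of the DirectQ4 kind avoids materializing the dequantized weights and instead consumes a packed buffer directly, which is where the 4-bit speedup comes from.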

3. Environment variable behavior

  • ORT_USE_MLAS_Q4_GEMM_MOE=1: force fast path when supported.
  • ORT_USE_MLAS_Q4_GEMM_MOE=0: force fallback path.
  • Unset: use default policy that enables fast path unless a known accuracy-loss pattern is detected.
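
The three-way behavior can be sketched as a small decision function. This is an illustration in Python, not the operator's C++ code: `use_fast_path`, `fast_path_supported`, and `known_accuracy_risk` are hypothetical names, and treating any value other than the literal strings "0" and "1" as unset is an assumption.

```python
import os
from typing import Mapping, Optional

ENV_VAR = "ORT_USE_MLAS_Q4_GEMM_MOE"


def use_fast_path(fast_path_supported: bool,
                  known_accuracy_risk: bool,
                  env: Optional[Mapping[str, str]] = None) -> bool:
    """Decide between the fast path (True) and the fallback path (False)."""
    env = os.environ if env is None else env
    value = env.get(ENV_VAR)
    if value == "0":
        return False  # forced fallback
    if not fast_path_supported:
        return False  # "=1" only forces the fast path when it is supported
    if value == "1":
        return True   # forced fast path, even for accuracy-sensitive cases
    # Unset (or, by assumption, any other value): default policy.
    return not known_accuracy_risk
```

Passing `env` explicitly keeps the function easy to test without touching the process environment.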

4. Test updates

  • QMoE CPU tests were refined to validate env-var on/off behavior and no-env behavior.
  • Coverage includes parity checks for symmetric/asymmetric, row-wise/block-wise settings.

Benchmark Results (1000 inferences, benchmark_qmoe.py)

Note: PyTorch latency fluctuates across runs and is excluded from conclusions below.

ORT results comparison

| Config | Baseline ORT Time (ms) | Baseline ORT tok/s | New ORT Time (env=0) (ms) | New ORT tok/s (env=0) | New ORT Time (env=1) (ms) | New ORT tok/s (env=1) |
|---|---:|---:|---:|---:|---:|---:|
| Medium-4bit | 748.594 | 1.3 | 237.219 | 4.2 | 178.943 | 5.6 |
| Medium-8bit | 209.277 | 4.8 | 212.074 | 4.7 | 203.882 | 4.9 |

ORT speedup vs baseline

| Config | env=0 speedup vs baseline (time) | env=1 speedup vs baseline (time) |
|---|---:|---:|
| Medium-4bit | 3.16x faster | 4.18x faster |
| Medium-8bit | 0.99x (about flat) | 1.03x faster |
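
The speedup figures are baseline time divided by new time. The snippet below recomputes them from the reported timings; the `speedup` helper and the `TIMINGS_MS` dictionary are names introduced here for illustration.

```python
# Reported ORT times in milliseconds: (baseline, new env=0, new env=1).
TIMINGS_MS = {
    "Medium-4bit": (748.594, 237.219, 178.943),
    "Medium-8bit": (209.277, 212.074, 203.882),
}


def speedup(baseline_ms: float, new_ms: float) -> float:
    """Time-based speedup: values above 1.0 mean the new build is faster."""
    return baseline_ms / new_ms


for config, (baseline, env0, env1) in TIMINGS_MS.items():
    print(f"{config}: env=0 {speedup(baseline, env0):.2f}x, "
          f"env=1 {speedup(baseline, env1):.2f}x")
```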

Accuracy Notes

  • env=1 (forced fast path) provides the best 4-bit performance but may show non-zero max diff in known cases.
  • env=0 (fallback) maintains parity behavior with zero observed max diff in the reported benchmark table.
  • Default no-env policy is designed to avoid known accuracy-loss cases while still enabling fast path where safe.
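
For reference, "max diff" here is read as the largest element-wise absolute difference between a reference output and the output under test. That reading, and the `max_abs_diff` helper below, are assumptions for illustration and not the exact check used in `test_qmoe_cpu.py`.

```python
from typing import Sequence


def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flat outputs."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same length")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)
```

A result of exactly 0.0 corresponds to the parity behavior described for the fallback path, while a forced fast path may produce a small non-zero value.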

Contributor

Copilot AI left a comment

Pull request overview

This PR introduces prepack-time optimization for QMoE (Quantized Mixture of Experts) CPU operations, achieving up to 4x performance improvement for 4-bit quantization by moving weight preprocessing and DirectQ4 GEMM cache building from runtime to initialization time. The implementation adds environment variable controls for A/B testing between fast-path (DirectQ4) and fallback (dequantize + MlasGemm) execution modes.

Changes:

  • Implements PrePack and UseSharedPrePackedBuffers for weight unpacking and cache building at initialization
  • Adds ORT_USE_MLAS_Q4_GEMM_MOE environment variable for runtime path selection with smart defaults
  • Extends test coverage to validate all execution modes (env=0, env=1, no env) for 4-bit configurations
  • Updates attribute naming from swiglu_interleaved to swiglu_fusion for consistency with operator conventions

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Summary per file:

| File | Description |
|---|---|
| test_qmoe_cpu.py | Adds test expansion for env var modes, updates swiglu_fusion attribute, adds bias support, removes debug code |
| benchmark_qmoe.py | New benchmark file for measuring QMoE throughput across configurations |
| mlas_q4.h | Clarifies documentation for MlasQ4GemmPackB expected data layout ([K, N] format) |
| debug_node_inputs_outputs_utils.cc | Updates debug message to indicate pre-packed tensors may have missing type info |
| moe_quantization_cpu.h | Adds PrePack/UseSharedPrePackedBuffers methods, cache storage, and env var control flags |
| moe_quantization_cpu.cc | Implements weight prepacking, DirectQ4 cache building, environment variable handling, and dual execution paths |
| moe_helper.h | Adds nullable weight tensor handling for prepacked weights with fallback logic |

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

@tianleiwu tianleiwu enabled auto-merge (squash) February 20, 2026 23:24
@tianleiwu tianleiwu merged commit b2a6e69 into main Feb 21, 2026
89 of 90 checks passed
@tianleiwu tianleiwu deleted the tlwu/20260216/qmoe_cpu_q4_perf branch February 21, 2026 08:46
tianleiwu added a commit that referenced this pull request Feb 27, 2026
This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|-------------|
| decd177 | #27090 | Fix GatherND division by zero when batch dimensions mismatch |
| 55f8234 | #27360 | Fix QMoE CPU Operator |
| df9146f | #27403 | [MLAS] Adding DynamicQGemm function pointers and ukernel interface |
| 0f93853 | #27318 | [js/web] Use embedded WASM module in Blob URL workers when wasmBinary is provided |
| b2a6e69 | #27364 | QMoE CPU Performance Update (Up to 4x on 4-bit) |
| f501e1d | #27413 | Fix refcount bug in map input conversion that caused shutdown segfault |
| b32b205 | #27421 | Fix error where bytes is not assigned for dynamic qgemm pack b size |
| 426b006 | #27397 | Fix DllImportResolver |
| 0982844 | #27412 | MatmulNBits prepacking scales fix |
| 9afb0d2 | #27430 | Fix validation for external data paths for models loaded from bytes |
| 71d2cd0 | #27401 | Enable Python 3.14 CI and Upgrade Dependencies |
| 79e0676 | #27419 | fix: out of bounds access for resize operation |
| 82eb99c | #27459 | Fix SkipLayerNorm fusion incorrectly applied when gamma/beta are not 1D |
| 355278a | #27444 | Fix GatherCopyData Integer Truncation Leading to Heap Out-of-Bounds Read/Write |
| cf96123 | #27411 | [web] fix usage of wasmBinary together with a blob URL for .mjs |
| 1131a86 | #27399 | [web] remove the unhelpful "Unknown CPU vendor" warning. |
| ffbbc4f | #27316 | Build Windows ARM64X binaries as part of packaging pipeline |
---------

Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com>
Co-authored-by: patryk-kaiser-ARM <patryk.kaiser@arm.com>
Co-authored-by: don <70039285+0-don@users.noreply.github.com>
Co-authored-by: Jonathan Clohessy <jonathan.clohessy@arm.com>
Co-authored-by: Hariharan Seshadri <shariharan91@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>
Co-authored-by: Lukas Folle <126877803+lukas-folle-snkeos@users.noreply.github.com>
Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Chaya <cha182350@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Erik <erscor@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>