fajin-corp and others added 15 commits May 1, 2025 00:09
### Description
Add 8-bit support for MatMulNBits on x86.

__AVX512 VNNI__
| M | N | K | 8-bit Time (ns) | 4-bit Time (ns) | Slowdown (8-bit / 4-bit) |
|:-----:|:-------:|:-------:|:----------------:|:----------------:|:------------------------:|
| 1 | 4096 | 4096 | 34145 | 27723 | **1.23×** |
| 1 | 11008 | 4096 | 415285 | 68656 | **6.05×** |
| 1 | 4096 | 11008 | 407801 | 68061 | **5.99×** |
| 1 | 11008 | 11008 | 2674538 | 1003532 | **2.67×** |
| 4096 | 4096 | 4096 | 80338759 | 86321713 | **0.93×** |
| 4096 | 11008 | 4096 | 213421935 | 225245276 | **0.95×** |
| 4096 | 4096 | 11008 | 240164365 | 228966953 | **1.05×** |
| 4096 | 11008 | 11008 | 628352046 | 596738340 | **1.05×** |

__AVX512__
| M | N | K | 8-bit Time (ns) | 4-bit Time (ns) | Slowdown (8-bit / 4-bit) |
|:-----:|:-------:|:-------:|:----------------:|:----------------:|:------------------------:|
| 1 | 4096 | 4096 | 53324 | 37882 | **1.41×** |
| 1 | 11008 | 4096 | 244560 | 103255 | **2.37×** |
| 1 | 4096 | 11008 | 435131 | 95734 | **4.55×** |
| 1 | 11008 | 11008 | 2790710 | 1075216 | **2.60×** |
| 4096 | 4096 | 4096 | 200629000 | 132841540 | **1.51×** |
| 4096 | 11008 | 4096 | 532141914 | 350613184 | **1.52×** |
| 4096 | 4096 | 11008 | 544011977 | 351679619 | **1.55×** |
| 4096 | 11008 | 11008 | 1421865147 | 925593210 | **1.54×** |

Token generation is bottlenecked by memory access; the 8-bit model being twice the size of the 4-bit model is the major reason for the token-generation slowdown.

On non-VNNI platforms, the sum of four i8 products does not fit in an i16 intermediate, so extra instructions are needed to avoid overflow. This is the major reason for the non-VNNI slowdown.
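
For illustration, here is a minimal sketch of the two accumulation paths (plain intrinsics usage, not the actual MLAS code):

```
#include <immintrin.h>

// With AVX512-VNNI, one instruction multiplies four u8/i8 pairs per i32
// lane and accumulates, so no 16-bit intermediate is involved.
__m512i dot_vnni(__m512i acc, __m512i a_u8, __m512i b_i8) {
    return _mm512_dpbusd_epi32(acc, a_u8, b_i8);
}

// Without VNNI, the usual emulation goes through i16 intermediates:
// _mm512_maddubs_epi16 adds two u8*i8 products into an i16 lane, which
// can saturate (255 * 127 * 2 = 64770 > 32767) for full 8-bit operands,
// so extra widening/overflow-avoidance instructions are needed.
__m512i dot_no_vnni(__m512i acc, __m512i a_u8, __m512i b_i8) {
    const __m512i ones = _mm512_set1_epi16(1);
    __m512i prod_i16 = _mm512_maddubs_epi16(a_u8, b_i8);   // may saturate at 8 bits
    __m512i prod_i32 = _mm512_madd_epi16(prod_i16, ones);  // widen pairs to i32
    return _mm512_add_epi32(acc, prod_i32);
}
```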

### Motivation and Context
The MatMul4Bits model has a repetition issue; a 6-bit model resolved this issue.
### Description
Support 8 bits in the MatMulNBits CUDA kernel.

The `MatMulFloat8bKernel` CUDA kernel performs a matrix-vector multiplication (GEMV) where matrix B is quantized per block using 8-bit integers.

The kernel computes $Output = A \times B$, where:
* $A$ is the activation matrix (shape `[M, K]`, typically M = 1 here) of type `T` (`float` or `half`).
* $B$ is a matrix (shape `[K, N]`) quantized using 8-bit unsigned integers (`uint8_t`) with a block structure. It is stored as `[N, K/block_size, block_size]`.
* `scales_data` contains the dequantization scales (shape `[N, K/block_size]`).
* `zero_points` contains the dequantization zero points (shape `[N, K/block_size]`), if used (`has_zero_point` is true).
* `output` is the result (shape `[M, N]`).
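
As a hedged sketch (hypothetical helper, using `float` scales for clarity where the real kernel uses type `T`), dequantizing one element of B looks like:

```
#include <cstddef>
#include <cstdint>

// Dequantize one 8-bit element of B at (n, k). The 128 default used when
// zero points are absent is an assumption, not confirmed by the PR text.
__device__ float DequantB(const uint8_t* b_quant, const float* scales,
                          const uint8_t* zero_points, int n, int k,
                          int K, int block_size) {
    int blocks_per_K = K / block_size;                 // assumes K % block_size == 0
    int meta_idx = n * blocks_per_K + k / block_size;  // index into scales/zero points
    float scale = scales[meta_idx];
    int zp = zero_points ? zero_points[meta_idx] : 128;
    int q = b_quant[(size_t)n * K + k];                // [N, K/bs, bs] flattens to n*K + k
    return scale * (float)(q - zp);
}
```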

The kernel uses a thread block structure of `(kWarpSize,
kColsPerThreadBlock)`, meaning each block handles `kColsPerThreadBlock`
(which is 8) columns of the output. Each warp within the block is
responsible for one output element (`[m_id, n_id]`). Threads within a
warp cooperate to compute the dot product along the K dimension. Each
thread (`lane_id`) handles `kElementsPerThreadPerIteration` (which is 8)
elements of the K dimension in each step.
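
In index terms, the mapping looks roughly like this (hypothetical variable names; a sketch, not the kernel itself):

```
constexpr int kWarpSize = 32;
constexpr int kColsPerThreadBlock = 8;
constexpr int kElementsPerThreadPerIteration = 8;
constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration;  // 256

__global__ void MatMulFloat8bKernelSketch() {
    int lane_id = threadIdx.x;                                   // 0..31 within the warp
    int warp_id = threadIdx.y;                                   // which output column this warp owns
    int n_id = blockIdx.x * kColsPerThreadBlock + warp_id;       // output column
    int m_id = blockIdx.y;                                       // output row
    int lane_offset = lane_id * kElementsPerThreadPerIteration;  // this lane's first K element
    // Each iteration the warp advances k_per_iter = 256 elements along K,
    // each lane reading its own 8-element chunk; a warp reduction at the
    // end combines the 32 per-lane partial sums into output[m_id, n_id].
    (void)n_id; (void)m_id; (void)lane_offset;  // sketch only
}
```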

Here's a breakdown of the three algorithms (`kKernelAlgo`):

1.  **`kKernelAlgo = 0` (Unrolling):**
    * **Strategy:** This algorithm processes the K dimension by iterating in large steps (`k_per_iter = kWarpSize * kElementsPerThreadPerIteration = 32 * 8 = 256`). Inside the main loop, it uses a macro (`UnRollReduction`) with `#pragma unroll` directives to aggressively unroll the innermost computations. It tries unrolling factors of 16, 4, and 1 sequentially to cover as much of the K dimension as possible with unrolled code.
    * **Pros:** Can significantly reduce loop overhead (branching instructions, counter updates) and expose more instruction-level parallelism, potentially hiding memory latency.
    * **Cons:** Can lead to a large increase in compiled code size (register pressure, potential instruction cache misses). The effectiveness heavily depends on the compiler and the specific GPU architecture. The multi-stage unrolling adds complexity. It requires `k_per_iter` to be a multiple of `block_size` for correct scale/zp indexing within the unrolled loop.
    * **Performance Expectation:** Potentially the highest performance *if* the unrolling is effective on the target hardware and doesn't cause resource issues (registers, cache). Often good for compute-bound or latency-bound scenarios where loop overhead is a bottleneck.
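
A minimal sketch of the staged-unroll shape (the actual `UnRollReduction` macro differs; `ProcessChunk` is a hypothetical stand-in for one warp-wide 256-element dot-product step):

```
__device__ float ProcessChunk(int k_base);  // hypothetical

__device__ float StagedUnrollReduction(int K, int k_per_iter) {
    float sum = 0.0f;
    int k_id = 0;
#define UNROLL_REDUCTION(FACTOR)                                               \
    for (; k_id + (FACTOR) * k_per_iter <= K; k_id += (FACTOR) * k_per_iter) { \
        _Pragma("unroll")                                                      \
        for (int i = 0; i < (FACTOR); ++i)                                     \
            sum += ProcessChunk(k_id + i * k_per_iter);                        \
    }
    UNROLL_REDUCTION(16)  // large unrolled strides first
    UNROLL_REDUCTION(4)   // then a medium factor for the remainder
    UNROLL_REDUCTION(1)   // then one 256-wide chunk at a time
#undef UNROLL_REDUCTION
    return sum;
}
```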

2.  **`kKernelAlgo = 1` (Simple Loop):**
    * **Strategy:** This algorithm also iterates along the K dimension in steps of `k_per_iter` (256), but uses a simple `for` loop without explicit `#pragma unroll`. It relies on the compiler's default loop optimization capabilities.
    * **Pros:** Simpler code, smaller code size compared to Algorithm 0. Less likely to cause register pressure or instruction cache issues. Easier for the compiler to reason about.
    * **Cons:** May incur higher loop overhead compared to effective unrolling. Performance might be lower if loop overhead is significant.
    * **Performance Expectation:** A solid baseline. Might be close to Algorithm 0 if the compiler performs implicit unrolling effectively, or faster if Algorithm 0 suffers from code bloat penalties.
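
A corresponding sketch of the simple loop (same hypothetical `ProcessChunk` as in the Algorithm 0 sketch):

```
__device__ float ProcessChunk(int k_base);  // hypothetical

__device__ float SimpleLoopReduction(int K, int k_per_iter) {
    float sum = 0.0f;
    for (int k_id = 0; k_id + k_per_iter <= K; k_id += k_per_iter) {
        sum += ProcessChunk(k_id);  // compiler decides whether to unroll
    }
    return sum;  // any K remainder is handled separately
}
```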

3.  **`kKernelAlgo = 2` (Block Size Iteration):**
    * **Strategy:** This algorithm changes the iteration strategy fundamentally. Instead of iterating in fixed steps of `k_per_iter`, it iterates based on the quantization `block_size`. The outer loop runs `blocks_per_K` (`K / block_size`) times. Inside this loop, the scale and zero point for the *entire block* are fetched once per warp. Then, each thread checks if its assigned K-elements (`lane_offset`) fall within the current `block_size` chunk and processes them using the fetched scale/zp.
    * **Pros:** Directly aligns with the block quantization data structure. Fetches scale/zero-point values less frequently (once per `block_size` chunk per warp), potentially reducing shared memory bank conflicts or register usage compared to calculating the index (`current_meta_k`) in every inner step as in Algo 0/1. Might have better memory access patterns for scale/zp data.
    * **Cons:** The outer loop iterates `K / block_size` times. If `block_size` is small (e.g., 16, 32), this could be many iterations. The logic inside the loop (`if (current_k_base < k_end_block ...)`) adds conditional execution.
    * **Performance Expectation:** Performance depends heavily on the `block_size`. If `block_size` is large (e.g., 128, 256), the number of outer loop iterations is small, and the efficiency gain from fetching scale/zp once per block might outweigh the overhead. If `block_size` is small, the overhead of the outer loop might dominate.
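
A sketch of the block-wise iteration (hypothetical names, assuming a 256-element warp stride; the 128 default zero point is also an assumption):

```
#include <cstdint>

__device__ float ProcessChunkWithMeta(int k_base, float scale, int zp);  // hypothetical

__device__ float BlockSizeReduction(int K, int block_size, int n_id, int lane_offset,
                                    const float* scales, const uint8_t* zero_points) {
    const int k_stride = 256;  // kWarpSize * kElementsPerThreadPerIteration
    float sum = 0.0f;
    int blocks_per_K = K / block_size;
    for (int block = 0; block < blocks_per_K; ++block) {
        // Fetch scale/zero point once per quantization block per warp.
        float scale = scales[n_id * blocks_per_K + block];
        int zp = zero_points ? zero_points[n_id * blocks_per_K + block] : 128;
        int k_start = block * block_size;
        int k_end_block = k_start + block_size;
        // This lane's chunks sit at lane_offset, lane_offset + 256, ... along K;
        // process only those that fall inside [k_start, k_end_block).
        for (int current_k_base = lane_offset; current_k_base < k_end_block;
             current_k_base += k_stride) {
            if (current_k_base >= k_start) {
                sum += ProcessChunkWithMeta(current_k_base, scale, zp);
            }
        }
    }
    return sum;
}
```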

**Next Steps:**

1. **Profile:** The most reliable way is to benchmark all three algorithms (`kKernelAlgo = 0, 1, 2`) on your target GPU hardware with representative input sizes (`N`, `K`), data types (`T`), and `block_size` values. Use profiling tools like NVIDIA Nsight Compute to analyze performance metrics (execution time, occupancy, instruction throughput, memory bandwidth, cache hit rates, register spills).
2. **Hypothesize based on `block_size`:**
    * For **large `block_size`** (e.g., 128, 256), Algorithm 2 might be competitive or even the best due to efficient scale/ZP handling. Algorithm 0 could also be very fast.
    * For **small `block_size`** (e.g., 16, 32), Algorithm 0 (unroll) or Algorithm 1 (simple loop) might outperform Algorithm 2 due to lower loop overhead in the K dimension.
3. Compare performance with TRT LLM FpA IntB GEMM.

### Motivation and Context
4-bit quantization has accuracy loss for some LLMs; some layers need more bits.
### Description
1. Add a benchmark script for MatMulNBits.
2. Update the kernel based on benchmark results:
   - Change the kernel back to handle M = 1.
   - Use the simple-loop kernel instead of unrolling.
   - Change the partial sum to float type to trade off precision and performance (less precision loss, no obvious performance drop); see the sketch below.
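
For the last item, a minimal sketch of the idea (hypothetical helper; the real kernel differs): keep the running dot product in `float` even when `T` is `half`.

```
#include <cuda_fp16.h>

// Accumulate partial sums in float regardless of T to reduce precision
// loss, converting only at the boundaries.
template <typename T>
__device__ float PartialDot(const T* a, const float* b_dequant, int count) {
    float sum = 0.0f;  // float accumulator even for T = half
    for (int i = 0; i < count; ++i) {
        sum += static_cast<float>(a[i]) * b_dequant[i];
    }
    return sum;
}
```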

Example output of benchmark:
```
------------------------------------------------------------------------------------------------------------------------
Benchmarking MatMulNBits on NVIDIA A100-SXM4-80GB (Compute Capability: 8.0)
------------------------------------------------------------------------------------------------------------------------
CUDA Graph   | M        | N        | K        | Bits   | Block Size | Threads  | Latency (us)    | StdDev (us)  | TFLOPS
------------------------------------------------------------------------------------------------------------------------
True         | 1        | 3072     | 8192     | 4      | 32         | 0        | 95.7            | 5.7          | 0.526
True         | 1        | 3072     | 8192     | 8      | 32         | 0        | 110.7           | 81.1         | 0.454
True         | 1        | 3072     | 8192     | 4      | 128        | 0        | 93.7            | 41.2         | 0.537
True         | 1        | 3072     | 8192     | 8      | 128        | 0        | 105.0           | 129.3        | 0.479
True         | 1        | 5120     | 3072     | 4      | 32         | 0        | 86.7            | 49.9         | 0.363
True         | 1        | 5120     | 3072     | 8      | 32         | 0        | 90.1            | 41.1         | 0.349
True         | 1        | 5120     | 3072     | 4      | 128        | 0        | 83.9            | 46.7         | 0.375
True         | 1        | 5120     | 3072     | 8      | 128        | 0        | 85.2            | 57.1         | 0.369
True         | 1        | 8192     | 3072     | 4      | 32         | 0        | 107.3           | 29.2         | 0.469
True         | 1        | 8192     | 3072     | 8      | 32         | 0        | 102.3           | 57.1         | 0.492
True         | 1        | 8192     | 3072     | 4      | 128        | 0        | 99.2            | 61.2         | 0.507
True         | 1        | 8192     | 3072     | 8      | 128        | 0        | 97.5            | 47.4         | 0.516
True         | 1        | 200064   | 3072     | 4      | 32         | 0        | 1456.4          | 11.0         | 0.844
True         | 1        | 200064   | 3072     | 8      | 32         | 0        | 1336.4          | 10.3         | 0.920
True         | 1        | 200064   | 3072     | 4      | 128        | 0        | 1261.6          | 16.6         | 0.974
True         | 1        | 200064   | 3072     | 8      | 128        | 0        | 1232.6          | 17.9         | 0.997
True         | 256      | 3072     | 8192     | 4      | 32         | 0        | 211.1           | 5.8          | 61.030
True         | 256      | 3072     | 8192     | 8      | 32         | 0        | 217.8           | 62.8         | 59.154
True         | 256      | 3072     | 8192     | 4      | 128        | 0        | 208.7           | 63.3         | 61.751
True         | 256      | 3072     | 8192     | 8      | 128        | 0        | 213.0           | 58.2         | 60.491
True         | 256      | 5120     | 3072     | 4      | 32         | 0        | 151.9           | 57.4         | 53.028
True         | 256      | 5120     | 3072     | 8      | 32         | 0        | 156.2           | 71.1         | 51.554
True         | 256      | 5120     | 3072     | 4      | 128        | 0        | 151.4           | 22.6         | 53.198
True         | 256      | 5120     | 3072     | 8      | 128        | 0        | 154.6           | 47.1         | 52.092
True         | 256      | 8192     | 3072     | 4      | 32         | 0        | 219.0           | 4.4          | 58.847
True         | 256      | 8192     | 3072     | 8      | 32         | 0        | 226.6           | 14.5         | 56.860
True         | 256      | 8192     | 3072     | 4      | 128        | 0        | 206.7           | 39.9         | 62.333
True         | 256      | 8192     | 3072     | 8      | 128        | 0        | 216.2           | 41.3         | 59.587
True         | 256      | 200064   | 3072     | 4      | 32         | 0        | 3110.9          | 11.3         | 101.152
True         | 256      | 200064   | 3072     | 8      | 32         | 0        | 3290.9          | 8.3          | 95.619
True         | 256      | 200064   | 3072     | 4      | 128        | 0        | 3055.2          | 10.2         | 102.995
True         | 256      | 200064   | 3072     | 8      | 128        | 0        | 3220.4          | 9.8          | 97.712
True         | 1024     | 3072     | 8192     | 4      | 32         | 0        | 363.6           | 40.2         | 141.754
True         | 1024     | 3072     | 8192     | 8      | 32         | 0        | 369.0           | 46.0         | 139.669
True         | 1024     | 3072     | 8192     | 4      | 128        | 0        | 362.8           | 55.6         | 142.052
True         | 1024     | 3072     | 8192     | 8      | 128        | 0        | 367.5           | 56.5         | 140.256
True         | 1024     | 5120     | 3072     | 4      | 32         | 0        | 221.6           | 58.1         | 145.383
True         | 1024     | 5120     | 3072     | 8      | 32         | 0        | 225.4           | 56.6         | 142.938
True         | 1024     | 5120     | 3072     | 4      | 128        | 0        | 220.2           | 36.9         | 146.306
True         | 1024     | 5120     | 3072     | 8      | 128        | 0        | 224.1           | 57.8         | 143.751
True         | 1024     | 8192     | 3072     | 4      | 32         | 0        | 346.2           | 41.8         | 148.854
True         | 1024     | 8192     | 3072     | 8      | 32         | 0        | 352.8           | 21.6         | 146.097
True         | 1024     | 8192     | 3072     | 4      | 128        | 0        | 344.5           | 18.9         | 149.627
True         | 1024     | 8192     | 3072     | 8      | 128        | 0        | 350.6           | 10.6         | 147.016
True         | 1024     | 200064   | 3072     | 4      | 32         | 0        | 6822.0          | 44.1         | 184.504
True         | 1024     | 200064   | 3072     | 8      | 32         | 0        | 7018.5          | 38.4         | 179.339
True         | 1024     | 200064   | 3072     | 4      | 128        | 0        | 6757.8          | 51.5         | 186.257
True         | 1024     | 200064   | 3072     | 8      | 128        | 0        | 6947.7          | 38.1         | 181.167
------------------------------------------------------------------------------------------------------------------------
```
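As a sanity check on the table, the TFLOPS column appears consistent with $TFLOPS = 2MNK / \text{latency}$: for the last row, $2 \times 1024 \times 200064 \times 3072 / 6947.7\,\mu s \approx 181.2$ TFLOPS, matching the reported 181.167.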
### Motivation and Context
Follow-up to #24509.
### Description

There is a build error with `--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=52`.

Some half2 functions like `__hfma2`, used in the MatMul 8-bit kernel, are not defined for sm < 53. Add an implementation that does not use half2 for those old GPUs.
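
For illustration, a fallback of this kind might look like the following (a hedged sketch, not the actual patch):

```
#include <cuda_fp16.h>

// Emulate a half2 fused multiply-add by widening to float on sm < 53,
// where __hfma2 is unavailable.
__device__ __half2 hfma2_fallback(__half2 a, __half2 b, __half2 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
    return __hfma2(a, b, c);
#else
    float2 fa = __half22float2(a), fb = __half22float2(b), fc = __half22float2(c);
    return __floats2half2_rn(fa.x * fb.x + fc.x, fa.y * fb.y + fc.y);
#endif
}
```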

Also fix another build error with CUDA 12.5, caused by an extra `const` in the MoE code for sm < 53.

### Motivation and Context

Fix nuget packaging pipeline, which uses
`CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual`.
### Description

Currently, flash attention is only enabled for sm8x and sm90, which means Blackwell GPUs will not use flash attention. This change enables flash attention for sm > 90.

Note that the flash attention implementation is not optimized for Blackwell, but it should be able to run on Blackwell GPUs.
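
In effect, the capability check changes along these lines (hypothetical condition, not the actual code):

```
// Treat any architecture at or above sm80 as flash-attention capable
// instead of allow-listing only sm8x and sm90, so sm100+ (Blackwell)
// qualifies as well.
bool UseFlashAttention(int sm) {
    return sm >= 80;  // previously limited to sm8x and sm90
}
```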

Future work:
* Integrate flash attn for hopper:
https://github.com/Dao-AILab/flash-attention/tree/main/hopper
* Integrate fmha for blackwell:
https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
* Update cudnn and cudnn frontend to latest version (so that we can use
the cudnn flash attention for blackwell).

### Motivation and Context
ORT GenAI is slow on RTX 5090.
…ull knowledge (#24568)

GetDeviceInfoIfSupported -> GetSupportedDevices

The EP sees all devices so it can make decisions with full knowledge. This is mainly applicable to GPU EPs like WebGPU.

The EP has to iterate over the devices and call CreateEpDevice for the devices it supports.
…#24587)

### Description
`LoadPluginOrProviderBridge` is called when attempting to load a Plugin. It uses the passed `library_path` to attempt to load the Plugin as a `Provider` - using `ProviderLibrary` - to see if it can be treated as a 'ProviderBridge'. `ProviderLibrary` attempts to load the Provider by prefixing the path with the onnxruntime.dll location. Plugins needn't be redistributed with OnnxRuntime, so the path to the Plugin _may_ be an absolute path, and if so `ProviderLibrary` fails. At the same time, however, `LoadPluginOrProviderBridge` needs to support OnnxRuntime-relative paths: as 'Providers' are migrated to 'Plugins', existing Providers should be usable as Plugins. To accommodate both scenarios, this PR:

1. Adds support to `ProviderLibrary` to be created with an absolute
path.
2. Validates the path passed to `LoadPluginOrProviderBridge`:
   1. If it is absolute, the same absolute path is passed to `ProviderLibrary` and `EpLibraryPlugin`.
   2. If the path is not absolute, it is converted to an absolute path by prefixing the OnnxRuntime location, and the same path is passed to `ProviderLibrary` and `EpLibraryPlugin`.
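
A hedged sketch of the path normalization described above (hypothetical helper, assuming `std::filesystem` semantics):

```
#include <filesystem>

std::filesystem::path ResolveLibraryPath(const std::filesystem::path& library_path,
                                         const std::filesystem::path& ort_dir) {
    // Absolute paths are used as-is; relative paths are anchored at the
    // OnnxRuntime location so existing Providers keep working as Plugins.
    return library_path.is_absolute() ? library_path : ort_dir / library_path;
}
```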

### Motivation and Context
This PR enables `LoadPluginOrProviderBridge` to be called with an
absolute path to the Plugin, allowing it to be used as a
'ProviderBridge', or with an OnnxRuntime-relative path to the Plugin.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
### Description

Add some logic to detect whether I8MM is actually supported.

This info can be read from the registry. See the helpful comments here
for more details:

https://github.com/Dr-Noob/cpufetch/blob/a0c08ccc0b64b524ad2122e0595099f73cbba9c4/src/arm/midr.c#L30-L52
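
For illustration, a check of this kind might look like the following hedged sketch (not the actual patch); the `CP 4031` value name and the I8MM field at bits [55:52] of ID_AA64ISAR1_EL1 follow the cpufetch comments linked above, and should both be treated as assumptions:

```
#include <windows.h>
#include <cstdint>

bool HasI8MM() {
    uint64_t isar1 = 0;
    DWORD size = sizeof(isar1);
    if (RegGetValueW(HKEY_LOCAL_MACHINE,
                     L"HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0",
                     L"CP 4031", RRF_RT_REG_QWORD, nullptr, &isar1, &size) != ERROR_SUCCESS) {
        return false;  // value missing: conservatively assume no I8MM
    }
    return ((isar1 >> 52) & 0xF) >= 1;  // ID_AA64ISAR1_EL1.I8MM
}
```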

### Motivation and Context

Detect I8MM correctly to enable better performance.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Registered the ScatterND Op in QNN EP
- Created the op as part of the Simple Op Builder
- Added unit test to verify the Op runs on QNN
- Skipping ScatterND on QNN CPU (To Do)

### Description

Add ScatterND Op Support in QNN EP



### Motivation and Context

Performance improvement, as the ScatterND Op otherwise falls back to ORT CPU due to missing support.
### Description
This PR incorporates the changes requested in PR #24394.

Changes are summarized below:
1. Reordered `enable_ovep_qdq_optimizer` to appear before all output parameters, as per the review suggestion. Other reorders are also done for clarity.
2. Replaced the non-release build check with the RELEASE flag for clarity. This allows all build configs except Release to dump the model.
- Introduces `USE_<EP>_PROVIDER_INTERFACE` pre-processor macros that indicate when an EP interface is enabled but the full EP is not being compiled.
  - Previously, the CMake configuration turned on `USE_<EP>` for both use cases. This prevented tests from determining whether the full EP or only the interface was available, which caused test failures. It also turned on all EP code paths in core ORT code at the same time, which caused compilation and logic errors.
- Adds the new NV EP to the list of EPs whose interface is enabled when ORT is built with `--enable_generic_interface`.
- Updates the Windows Arm64 QNN CI Pipeline to actually use the `--enable_generic_interface` flag.
  - Previously, it was not actually being passed to the build command, so no unit tests were being run with the flag enabled.
- Adds unit tests to check that adding an EP to the session options fails when only the generic interface (but not the full EP) is built.
- Windows ARM64 QNN CI Pipeline:
  - Builds ORT with `--use_qnn --enable_generic_interface` and runs all normal QNN EP unit tests.
  - Builds ORT with `--use_qnn --enable_generic_interface` and runs new unit tests that try to add the following EPs to the session options (expect failure): OpenVINO, CUDA, NV, TensorRT, VitisAI.
- Build and Test OpenVINO EP (AlmaLinux8, Py3.12) / build_test_pipeline:
  - Builds ORT with `--use_openvino --enable_generic_interface` and runs all normal OpenVINO EP unit tests.
  - Builds ORT with `--use_openvino --enable_generic_interface` and runs new unit tests that try to add the following EPs to the session options (expect failure): QNN, CUDA, NV, TensorRT, VitisAI.
- windows_x64_release_ep_generic_interface:
  - Builds ORT with `--enable_generic_interface` and now runs CPU EP unit tests (didn't previously).

Fix the use of `--enable_generic_interface` and make sure tests actually run.
### Description
Update the QNN nuget package to use the Arm64x binary.
Enable build with the generic interface.
Copy QNN libs with the QNN EP project build instead of the test_all project.
Update the DML nuget package to enable the generic interface, and pack the shared.dll into the package.
…lls (#24606)

### Description
Fixes #24500

- Fixes local build of onnxruntime.dll to have a valid version, such as
"1.23.0", instead of the literal string "ORT_VERSION"
- Adds version info to onnxruntime_providers_qnn.dll,
onnxruntime_providers_cuda.dll, and onnxruntime_providers_tensorrt.dll.
It was missing completely. This was done by adding
`onnxruntime_providers_*.rc` files to define each EP's [DLL version
info](https://learn.microsoft.com/en-us/windows/win32/menurc/versioninfo-resource).

Fixed onnxruntime.dll version info (local non-ADO build):
<img width="263" alt="image"
src="https://github.com/user-attachments/assets/33ef85ea-ac36-4c6a-9171-8fe4fb35955d"
/>

Fixed onnxruntime_providers_qnn.dll version info (adds QNN SDK version
too):
<img width="275" alt="image"
src="https://github.com/user-attachments/assets/a1f04604-2e3c-416d-989e-e92cb7df1776"
/>


### Motivation and Context
We create dlls with invalid or missing version info.
@vraspar vraspar requested a review from a team as a code owner May 1, 2025 07:29
@jywu-msft jywu-msft merged commit cf92d98 into rel-1.22.0 May 1, 2025
98 of 143 checks passed
@jywu-msft jywu-msft deleted the vraspar/rel1.22/cherry_picks_round2 branch May 1, 2025 17:29