@tianleiwu tianleiwu commented Apr 23, 2025

Description

Support 8 bits in the MatMulNBits CUDA kernel.

The MatMulFloat8bKernel CUDA kernel performs a matrix multiplication (GEMM; a matrix-vector product when M = 1) where matrix B is quantized per block using 8-bit integers.

The kernel computes $Output = A \times B$, where:

  • $A$ is the activation matrix (shape [M, K]) of type T (float or half).
  • $B$ is the weight matrix (shape [K, N]), quantized with 8-bit unsigned integers (uint8_t) in a blockwise layout. It is stored as [N, K/block_size, block_size] (see the dequantization sketch after this list).
  • scales_data contains the dequantization scales (shape [N, K/block_size]).
  • zero_points contains the dequantization zero points (shape [N, K/block_size]) when has_zero_point is true.
  • output is the resulting matrix (shape [M, N]).
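
These buffers imply a standard blockwise dequantization. Below is a minimal sketch of how a single element of B would be recovered; it is illustrative rather than the kernel's actual code, and the default zero point of 128 (i.e., 2^(bits-1)) when zero_points is absent is an assumption.

```cuda
// Hedged sketch: recover one element of B from the quantized layout
// [N, K/block_size, block_size]. Function and variable names are illustrative.
template <typename T>
__device__ T DequantizeElementOfB(const uint8_t* b_data, const T* scales_data,
                                  const uint8_t* zero_points, int n, int k,
                                  int K, int block_size, bool has_zero_point) {
  const int blocks_per_K = K / block_size;
  const int meta_idx = n * blocks_per_K + (k / block_size);  // per-block scale/zp index
  const uint8_t q = b_data[n * K + k];  // column n of B is contiguous along K
  // Assumed default zero point 1 << (bits - 1) = 128 when zero_points is absent.
  const int zp = has_zero_point ? static_cast<int>(zero_points[meta_idx]) : 128;
  const float scale = static_cast<float>(scales_data[meta_idx]);
  return static_cast<T>(scale * static_cast<float>(static_cast<int>(q) - zp));
}
```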

The kernel uses a thread block structure of (kWarpSize, kColsPerThreadBlock), meaning each block handles kColsPerThreadBlock (which is 8) columns of the output. Each warp within the block is responsible for one output element ([m_id, n_id]). Threads within a warp cooperate to compute the dot product along the K dimension. Each thread (lane_id) handles kElementsPerThreadPerIteration (which is 8) elements of the K dimension in each step.
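
A hedged sketch of the launch geometry and warp cooperation this implies (names mirror the description above; the kernel body is abbreviated, and the grid shape is an assumption):

```cuda
// Illustrative sketch, not the actual kernel source.
constexpr int kWarpSize = 32;
constexpr int kColsPerThreadBlock = 8;             // output columns per thread block
constexpr int kElementsPerThreadPerIteration = 8;  // K elements per lane per step

// Assumed launch shape:
//   dim3 block(kWarpSize, kColsPerThreadBlock);   // one warp per output column
//   dim3 grid((N + kColsPerThreadBlock - 1) / kColsPerThreadBlock, M);
template <typename T>
__global__ void MatMulFloat8bSketch(T* output, /* A, B, scales, zp, ... */ int N, int K) {
  const int lane_id = threadIdx.x;                                   // 0..31 within the warp
  const int n_id = blockIdx.x * kColsPerThreadBlock + threadIdx.y;   // output column (one warp each)
  const int m_id = blockIdx.y;                                       // output row
  const int lane_offset = lane_id * kElementsPerThreadPerIteration;  // lane's offset in each 256-wide chunk

  float sum = 0.0f;
  // ... each lane accumulates its share of the dot product over K,
  //     using one of the three iteration strategies described below ...

  // Warp-level reduction: combine the 32 partial sums.
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (lane_id == 0 && n_id < N) {
    output[m_id * N + n_id] = static_cast<T>(sum);
  }
}
```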

Here's a breakdown of the three algorithms (kKernelAlgo):

  1. kKernelAlgo = 0 (Unrolling):

    • Strategy: This algorithm processes the K dimension by iterating in large steps (k_per_iter = kWarpSize * kElementsPerThreadPerIteration = 32 * 8 = 256). Inside the main loop, it uses a macro (UnRollReduction) with #pragma unroll directives to aggressively unroll the innermost computations. It tries unrolling factors of 16, 4, and 1 sequentially to cover as much of the K dimension as possible with unrolled code.
    • Pros: Can significantly reduce loop overhead (branching instructions, counter updates) and expose more instruction-level parallelism, potentially hiding memory latency.
    • Cons: Can lead to a large increase in compiled code size (register pressure, potential instruction cache misses). The effectiveness heavily depends on the compiler and the specific GPU architecture. The multi-stage unrolling adds complexity. It requires k_per_iter to be a multiple of block_size for correct scale/zp indexing within the unrolled loop.
    • Performance Expectation: Potentially the highest performance if the unrolling is effective on the target hardware and doesn't cause resource issues (registers, cache). Often good for compute-bound or latency-bound scenarios where loop overhead is a bottleneck.
  2. kKernelAlgo = 1 (Simple Loop):

    • Strategy: This algorithm also iterates along the K dimension in steps of k_per_iter (256), but uses a simple for loop without explicit #pragma unroll. It relies on the compiler's default loop optimization capabilities.
    • Pros: Simpler code, smaller code size compared to Algorithm 0. Less likely to cause register pressure or instruction cache issues. Easier for the compiler to reason about.
    • Cons: May incur higher loop overhead compared to effective unrolling. Performance might be lower if loop overhead is significant.
    • Performance Expectation: A solid baseline. Might be close to Algorithm 0 if the compiler performs implicit unrolling effectively, or faster if Algorithm 0 suffers from code bloat penalties.
  3. kKernelAlgo = 2 (Block Size Iteration):

    • Strategy: This algorithm changes the iteration strategy fundamentally. Instead of iterating in fixed steps of k_per_iter, it iterates based on the quantization block_size (see the sketch after this list). The outer loop runs blocks_per_K (K / block_size) times. Inside this loop, the scale and zero point for the entire block are fetched once per warp. Then, each thread checks if its assigned K-elements (lane_offset) fall within the current block_size chunk and processes them using the fetched scale/zp.
    • Pros: Directly aligns with the block quantization data structure. Fetches scale/zero-point values less frequently (once per block_size chunk per warp), potentially reducing shared memory bank conflicts or register usage compared to calculating the index (current_meta_k) in every inner step as in Algo 0/1. Might have better memory access patterns for scale/zp data.
    • Cons: The outer loop iterates K / block_size times. If block_size is small (e.g., 16, 32), this could be many iterations. The logic inside the loop (if (current_k_base < k_end_block ...)) adds conditional execution.
    • Performance Expectation: Performance depends heavily on the block_size. If block_size is large (e.g., 128, 256), the number of outer loop iterations is small, and the efficiency gain from fetching scale/zp once per block might outweigh the overhead. If block_size is small, the overhead of the outer loop might dominate.
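
To make Algorithm 2's iteration order concrete, here is a hedged sketch of its loop structure. It is simplified to the common case where block_size divides k_per_iter (so each quantization block lies within a single 256-element chunk); a_row and b_col are assumed pointers to row m_id of A and column n_id of the quantized B, and lane_offset is as in the earlier sketch.

```cuda
// Simplified sketch of kKernelAlgo = 2 (not the actual kernel source).
// Assumes block_size divides k_per_iter (= 256) and block_size >= 8.
constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration;  // 256
const int blocks_per_K = K / block_size;
float sum = 0.0f;
for (int block_id = 0; block_id < blocks_per_K; ++block_id) {
  // Scale and zero point fetched once per quantization block per warp.
  const float scale = static_cast<float>(scales_data[n_id * blocks_per_K + block_id]);
  const int zp = has_zero_point ? zero_points[n_id * blocks_per_K + block_id]
                                : 128;  // assumed 8-bit default zero point
  const int k_begin = block_id * block_size;
  const int k_end_block = k_begin + block_size;
  // This lane's position within the 256-wide chunk containing the block.
  const int current_k_base = (k_begin / k_per_iter) * k_per_iter + lane_offset;
  // Conditional execution: only lanes whose elements land in this block work.
  if (current_k_base >= k_begin && current_k_base < k_end_block) {
    for (int i = 0; i < kElementsPerThreadPerIteration; ++i) {
      const int k = current_k_base + i;
      sum += static_cast<float>(a_row[k]) *
             (scale * static_cast<float>(static_cast<int>(b_col[k]) - zp));
    }
  }
}
// 'sum' is then warp-reduced as in the earlier sketch.
```

Note how small block_size values leave most lanes idle in each outer iteration (only block_size / 8 lanes pass the check), which is the divergence cost discussed above.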

Next Steps:

  1. Profile: The most reliable way is to benchmark all three algorithms (kKernelAlgo = 0, 1, 2) on your target GPU hardware with representative input sizes (N, K), data types (T), and block_size values. Use profiling tools like NVIDIA Nsight Compute to analyze performance metrics (execution time, occupancy, instruction throughput, memory bandwidth, cache hit rates, register spills).
  2. Hypothesize based on block_size:
    • For large block_size (e.g., 128, 256), Algorithm 2 might be competitive or even the best due to efficient scale/ZP handling. Algorithm 0 could also be very fast.
    • For small block_size (e.g., 16, 32), Algorithm 0 (unroll) or Algorithm 1 (simple loop) might outperform Algorithm 2 due to lower loop overhead in the K dimension.
  3. Compare performance with TensorRT-LLM's FpA IntB GEMM.

Motivation and Context

4-bit quantization loses accuracy on some LLMs; some layers need more bits.

@tianleiwu tianleiwu marked this pull request as draft April 23, 2025 00:02
@tianleiwu tianleiwu marked this pull request as ready for review April 23, 2025 23:10
@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

jiafatom previously approved these changes Apr 25, 2025
@tianleiwu tianleiwu merged commit 3a7c8b3 into main Apr 25, 2025
86 of 88 checks passed
@tianleiwu tianleiwu deleted the tlwu/cuda_matmul_8bits branch April 25, 2025 17:06
snnn pushed a commit that referenced this pull request Apr 26, 2025
### Description
1. Add benchmark script for MatMulNBits.
2. Update kernel based on benchmark results:
   - Change kernel back to handle m=1
   - Use simple loop kernel instead of unrolling
   - Change partial sum to float type to trade off precision and performance (less precision loss, no obvious performance drop); see the sketch below
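
The float partial-sum change amounts to accumulating half-precision products in a 32-bit register, roughly as below (a minimal sketch; a_vals and b_dequant are illustrative names for a lane's loaded and dequantized fragments):

```cuda
// Accumulate in float even when T is half: cheap on modern GPUs, and it
// avoids compounding rounding error over long K reductions.
float partial = 0.0f;
for (int i = 0; i < kElementsPerThreadPerIteration; ++i) {
  partial += __half2float(a_vals[i]) * __half2float(b_dequant[i]);
}
```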

Example output of benchmark:
```
------------------------------------------------------------------------------------------------------------------------
Benchmarking MatMulNBits on NVIDIA A100-SXM4-80GB (Compute Capability: 8.0)
------------------------------------------------------------------------------------------------------------------------
CUDA Graph   | M        | N        | K        | Bits   | Block Size | Threads  | Latency (us)    | StdDev (us)  | TFLOPS
------------------------------------------------------------------------------------------------------------------------
True         | 1        | 3072     | 8192     | 4      | 32         | 0        | 95.7            | 5.7          | 0.526
True         | 1        | 3072     | 8192     | 8      | 32         | 0        | 110.7           | 81.1         | 0.454
True         | 1        | 3072     | 8192     | 4      | 128        | 0        | 93.7            | 41.2         | 0.537
True         | 1        | 3072     | 8192     | 8      | 128        | 0        | 105.0           | 129.3        | 0.479
True         | 1        | 5120     | 3072     | 4      | 32         | 0        | 86.7            | 49.9         | 0.363
True         | 1        | 5120     | 3072     | 8      | 32         | 0        | 90.1            | 41.1         | 0.349
True         | 1        | 5120     | 3072     | 4      | 128        | 0        | 83.9            | 46.7         | 0.375
True         | 1        | 5120     | 3072     | 8      | 128        | 0        | 85.2            | 57.1         | 0.369
True         | 1        | 8192     | 3072     | 4      | 32         | 0        | 107.3           | 29.2         | 0.469
True         | 1        | 8192     | 3072     | 8      | 32         | 0        | 102.3           | 57.1         | 0.492
True         | 1        | 8192     | 3072     | 4      | 128        | 0        | 99.2            | 61.2         | 0.507
True         | 1        | 8192     | 3072     | 8      | 128        | 0        | 97.5            | 47.4         | 0.516
True         | 1        | 200064   | 3072     | 4      | 32         | 0        | 1456.4          | 11.0         | 0.844
True         | 1        | 200064   | 3072     | 8      | 32         | 0        | 1336.4          | 10.3         | 0.920
True         | 1        | 200064   | 3072     | 4      | 128        | 0        | 1261.6          | 16.6         | 0.974
True         | 1        | 200064   | 3072     | 8      | 128        | 0        | 1232.6          | 17.9         | 0.997
True         | 256      | 3072     | 8192     | 4      | 32         | 0        | 211.1           | 5.8          | 61.030
True         | 256      | 3072     | 8192     | 8      | 32         | 0        | 217.8           | 62.8         | 59.154
True         | 256      | 3072     | 8192     | 4      | 128        | 0        | 208.7           | 63.3         | 61.751
True         | 256      | 3072     | 8192     | 8      | 128        | 0        | 213.0           | 58.2         | 60.491
True         | 256      | 5120     | 3072     | 4      | 32         | 0        | 151.9           | 57.4         | 53.028
True         | 256      | 5120     | 3072     | 8      | 32         | 0        | 156.2           | 71.1         | 51.554
True         | 256      | 5120     | 3072     | 4      | 128        | 0        | 151.4           | 22.6         | 53.198
True         | 256      | 5120     | 3072     | 8      | 128        | 0        | 154.6           | 47.1         | 52.092
True         | 256      | 8192     | 3072     | 4      | 32         | 0        | 219.0           | 4.4          | 58.847
True         | 256      | 8192     | 3072     | 8      | 32         | 0        | 226.6           | 14.5         | 56.860
True         | 256      | 8192     | 3072     | 4      | 128        | 0        | 206.7           | 39.9         | 62.333
True         | 256      | 8192     | 3072     | 8      | 128        | 0        | 216.2           | 41.3         | 59.587
True         | 256      | 200064   | 3072     | 4      | 32         | 0        | 3110.9          | 11.3         | 101.152
True         | 256      | 200064   | 3072     | 8      | 32         | 0        | 3290.9          | 8.3          | 95.619
True         | 256      | 200064   | 3072     | 4      | 128        | 0        | 3055.2          | 10.2         | 102.995
True         | 256      | 200064   | 3072     | 8      | 128        | 0        | 3220.4          | 9.8          | 97.712
True         | 1024     | 3072     | 8192     | 4      | 32         | 0        | 363.6           | 40.2         | 141.754
True         | 1024     | 3072     | 8192     | 8      | 32         | 0        | 369.0           | 46.0         | 139.669
True         | 1024     | 3072     | 8192     | 4      | 128        | 0        | 362.8           | 55.6         | 142.052
True         | 1024     | 3072     | 8192     | 8      | 128        | 0        | 367.5           | 56.5         | 140.256
True         | 1024     | 5120     | 3072     | 4      | 32         | 0        | 221.6           | 58.1         | 145.383
True         | 1024     | 5120     | 3072     | 8      | 32         | 0        | 225.4           | 56.6         | 142.938
True         | 1024     | 5120     | 3072     | 4      | 128        | 0        | 220.2           | 36.9         | 146.306
True         | 1024     | 5120     | 3072     | 8      | 128        | 0        | 224.1           | 57.8         | 143.751
True         | 1024     | 8192     | 3072     | 4      | 32         | 0        | 346.2           | 41.8         | 148.854
True         | 1024     | 8192     | 3072     | 8      | 32         | 0        | 352.8           | 21.6         | 146.097
True         | 1024     | 8192     | 3072     | 4      | 128        | 0        | 344.5           | 18.9         | 149.627
True         | 1024     | 8192     | 3072     | 8      | 128        | 0        | 350.6           | 10.6         | 147.016
True         | 1024     | 200064   | 3072     | 4      | 32         | 0        | 6822.0          | 44.1         | 184.504
True         | 1024     | 200064   | 3072     | 8      | 32         | 0        | 7018.5          | 38.4         | 179.339
True         | 1024     | 200064   | 3072     | 4      | 128        | 0        | 6757.8          | 51.5         | 186.257
True         | 1024     | 200064   | 3072     | 8      | 128        | 0        | 6947.7          | 38.1         | 181.167
------------------------------------------------------------------------------------------------------------------------
```
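For reference, the TFLOPS column follows from the standard matrix-multiplication FLOP count, $\text{TFLOPS} = 2MNK / (\text{latency in seconds} \times 10^{12})$. For example, the last row: $2 \times 1024 \times 200064 \times 3072 \approx 1.259 \times 10^{12}$ FLOPs in $6947.7\ \mu s$ gives about $181.2$ TFLOPS, matching the reported value.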
### Motivation and Context
Follow-up to #24509.
jywu-msft pushed a commit that referenced this pull request Apr 30, 2025
### Description

Cherry-pick the following into
[rel-1.22.0](https://github.com/microsoft/onnxruntime/tree/rel-1.22.0):


- (#24487)
- (#24466)
- (#24493)
- (#24484)
- (#24494)
- (#24489)
- (#24504)
- (#24510)
- (#24456)
- (#24537)
- (#24501)
- (#24519)
- (#24513)
- (#24539)
- (#24514)
- (#24542)
- (#24585)

Not added:

Planning to cherry pick Cuda Matmulnbits PRs once the fix for failing
cuda pipeline is ready
- (#24491)
- (#24509)
- (#24564)

---------

Co-authored-by: Adrian Lizarraga <[email protected]>
Co-authored-by: minfhong-quic <[email protected]>
Co-authored-by: minfhong-quic <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: Edward Chen <[email protected]>
Co-authored-by: Ankan Banerjee <[email protected]>
Co-authored-by: Maximilian Müller <[email protected]>
Co-authored-by: Gaurav Garg <[email protected]>
Co-authored-by: iraut <[email protected]>
Co-authored-by: Hrishikesh Manohar <[email protected]>
Co-authored-by: Maximilian Müller <[email protected]>
Co-authored-by: Scott McKay <[email protected]>
Co-authored-by: Jiajia Qin <[email protected]>
Co-authored-by: kunal-vaishnavi <[email protected]>
Co-authored-by: xhcao <[email protected]>
jatinwadhwa921 pushed a commit to intel/onnxruntime that referenced this pull request Apr 30, 2025
vraspar pushed a commit that referenced this pull request May 1, 2025
vraspar pushed a commit that referenced this pull request May 1, 2025
jywu-msft pushed a commit that referenced this pull request May 1, 2025
### Description

Cherry-pick the following into
[rel-1.22.0](https://github.com/microsoft/onnxruntime/tree/rel-1.22.0):

- (#24491)
- (#24509)
- (#24564)
- (#24574)
- (#24582)
- (#24584)
- (#24568)
- (#24587)
- (#24563)
- (#24592)
- (#24526)
- (#24552)
- (#24588)
- (#24605)
- (#24606)

---------

Co-authored-by: Jing Fang <[email protected]>
Co-authored-by: Tianlei Wu <[email protected]>
Co-authored-by: Baiju Meswani <[email protected]>
Co-authored-by: Scott McKay <[email protected]>
Co-authored-by: Mark Schofield <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Edward Chen <[email protected]>
Co-authored-by: Ashwath Shankarnarayan <[email protected]>
Co-authored-by: saurabh <[email protected]>
Co-authored-by: Adrian Lizarraga <[email protected]>
Co-authored-by: Hector Li <[email protected]>
ankitm3k pushed a commit to intel/onnxruntime that referenced this pull request May 12, 2025
ankitm3k pushed a commit to intel/onnxruntime that referenced this pull request May 12, 2025
hariharans29 added a commit that referenced this pull request Jun 4, 2025
### Description

It seems like #24509 added a guard for the 8-bit MatMul tests that depends on an MLAS macro being set in order to compile and run on CPUs. However, the guard itself prevented inclusion of the MLAS header where that macro would have been defined, so the 8-bit MatMul tests were not being compiled and run in CPU builds.
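
In generic form, the failure mode looks like this (the macro and header names below are illustrative, not the actual ones):

```cuda
// Broken: the guard tests a macro that is only defined inside the MLAS
// header, so the header is never included and the tests compile to nothing.
#if defined(SOME_MLAS_FEATURE_MACRO)   // illustrative macro name
#include "mlas.h"                      // this is where the macro would be set
// ... 8-bit MatMul tests ...
#endif

// Fixed: include the header first, then guard on the macro it defines.
#include "mlas.h"
#if defined(SOME_MLAS_FEATURE_MACRO)
// ... 8-bit MatMul tests ...
#endif
```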

### Motivation and Context
Improve test coverage for CPU builds
javier-intel pushed a commit to intel/onnxruntime that referenced this pull request Jun 15, 2025
@snnn snnn commented Sep 5, 2025

This PR has been included in the rel-1.22.0 branch. Removing the release:1.22.0 label.

quic-ankus pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Nov 25, 2025