-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce lowbit quantized linear MPS kernels (#954)
Summary: The following is the directory structure of the submitted code under torchao ``` experimental/ ├── kernels/ │ └── mps/ │ ├── metal/ │ │ └── (metal shaders) │ ├── src/ │ │ └── (tensor agnostic mps kernel implementations) │ └── test/ │ │ └── (directly test mps kernel implementations) └── ops/ └── mps/ ├── register.mm ├── setup.py └── test/ └── (test torch custom ops) ``` Differential Revision: D63342895
- Loading branch information
1 parent
107e378
commit 6c16c82
Showing
15 changed files
with
1,641 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
/** | ||
* LowBit Quantized Linear for bitwidths that are divisors of 8. Hence the name. | ||
* | ||
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) | ||
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8) | ||
* @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2 | ||
* @param[outputData] M x N output tensor of floating point dtype (same as input) | ||
* @param[sizes] The sizes involved in the order: M, K, N | ||
* | ||
* Dispatched threads: N x M x 1 | ||
*/ | ||
template<typename T, unsigned nbit, unsigned groupSize> | ||
kernel void divbit_mm( | ||
constant T * A [[buffer(0)]], | ||
constant uchar * B [[buffer(1)]], | ||
constant T * scalesAndZeros [[buffer(2)]], | ||
device T * outputData [[buffer(3)]], | ||
constant uint3 & sizes [[buffer(4)]], // M, K, N | ||
uint2 thread_index [[thread_position_in_grid]]) { | ||
const uint K = sizes.y; | ||
const uint N = sizes.z; | ||
const uint m = thread_index.y; // 0..M-1 | ||
const uint n = thread_index.x; // 0..N-1 | ||
const uint32_t k_block = (K + groupSize - 1) / groupSize; | ||
constant T *A_ptr = A + m * K; | ||
constant uchar *B_ptr = B; | ||
|
||
constexpr uint8_t zero_shift = 1 << (nbit - 1); | ||
constexpr uint8_t values_per_byte = 8 / nbit; | ||
constexpr uint8_t minimask = (1 << nbit) - 1; | ||
|
||
float rc = 0.0; | ||
uint k = 0; | ||
for (uint32_t kb = 0; kb < k_block ; kb ++) { | ||
const T scale = scalesAndZeros[(kb * N + n) * 2 + 0]; | ||
const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift); | ||
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { | ||
const auto a_val = float(A_ptr[k]); | ||
uint8_t b_val = B_ptr[(n * K + k) / values_per_byte]; | ||
uint8_t shift = nbit * (k % values_per_byte); | ||
uint8_t mask = minimask << shift; | ||
b_val = (b_val & mask) >> shift; | ||
rc += a_val * float(scale * T(b_val) + zero); | ||
} | ||
} | ||
outputData[m * N + n] = T(rc); | ||
} | ||
|
||
// Explicit instantiation of divbit_mm for a given (NBIT, DTYPE, GSIZE) triple.
// [[host_name]] assigns the stable, host-visible function name
// "int<NBIT>pack_mm_<GSIZE>_<DTYPE>" under which the kernel can be looked up.
#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE)                        \
template                                                                 \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]]                  \
kernel void divbit_mm<DTYPE, NBIT, GSIZE>(                               \
    constant DTYPE * A [[buffer(0)]],                                    \
    constant uchar * B [[buffer(1)]],                                    \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData [[buffer(3)]],                             \
    constant uint3 & sizes [[buffer(4)]],                                \
    uint2 thread_index [[thread_position_in_grid]])

// 1-bit kernels for all supported group sizes.
INSTANTIATE_DIVBIT_MM(1, float, 32);
INSTANTIATE_DIVBIT_MM(1, half, 32);
INSTANTIATE_DIVBIT_MM(1, float, 64);
INSTANTIATE_DIVBIT_MM(1, half, 64);
INSTANTIATE_DIVBIT_MM(1, float, 128);
INSTANTIATE_DIVBIT_MM(1, half, 128);
INSTANTIATE_DIVBIT_MM(1, float, 256);
INSTANTIATE_DIVBIT_MM(1, half, 256);
#if __METAL_VERSION__ >= 310
// bfloat requires Metal 3.1 or newer.
INSTANTIATE_DIVBIT_MM(1, bfloat, 32);
INSTANTIATE_DIVBIT_MM(1, bfloat, 64);
INSTANTIATE_DIVBIT_MM(1, bfloat, 128);
INSTANTIATE_DIVBIT_MM(1, bfloat, 256);
#endif

// 2-bit kernels.
INSTANTIATE_DIVBIT_MM(2, float, 32);
INSTANTIATE_DIVBIT_MM(2, half, 32);
INSTANTIATE_DIVBIT_MM(2, float, 64);
INSTANTIATE_DIVBIT_MM(2, half, 64);
INSTANTIATE_DIVBIT_MM(2, float, 128);
INSTANTIATE_DIVBIT_MM(2, half, 128);
INSTANTIATE_DIVBIT_MM(2, float, 256);
INSTANTIATE_DIVBIT_MM(2, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(2, bfloat, 32);
INSTANTIATE_DIVBIT_MM(2, bfloat, 64);
INSTANTIATE_DIVBIT_MM(2, bfloat, 128);
INSTANTIATE_DIVBIT_MM(2, bfloat, 256);
#endif

// 4-bit kernels.
INSTANTIATE_DIVBIT_MM(4, float, 32);
INSTANTIATE_DIVBIT_MM(4, half, 32);
INSTANTIATE_DIVBIT_MM(4, float, 64);
INSTANTIATE_DIVBIT_MM(4, half, 64);
INSTANTIATE_DIVBIT_MM(4, float, 128);
INSTANTIATE_DIVBIT_MM(4, half, 128);
INSTANTIATE_DIVBIT_MM(4, float, 256);
INSTANTIATE_DIVBIT_MM(4, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(4, bfloat, 32);
INSTANTIATE_DIVBIT_MM(4, bfloat, 64);
INSTANTIATE_DIVBIT_MM(4, bfloat, 128);
INSTANTIATE_DIVBIT_MM(4, bfloat, 256);
#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
/**
 * 3-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1 (one thread per output element)
 *
 * Packing layout (3 bytes encode 8 consecutive 3-bit values):
 *   b0 bit i (i = 0..7)  -> high bit (bit 2) of value i
 *   b1 bit pairs         -> low 2 bits of values 0..3
 *   b2 bit pairs         -> low 2 bits of values 4..7
 *
 * NOTE(review): the inner loop advances by 8 and only checks `k < K` at loop
 * entry, so K is assumed to be a multiple of 8 (consistent with the
 * N x (3 * K / 8) weight shape); otherwise the last iteration reads past K.
 */
template<typename T, unsigned groupSize>
kernel void int3pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along K (ceiling division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  // Row n of the packed weights: 3*K/8 bytes per row.
  constant uchar *B_ptr = B + n * 3 * K / 8;

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Fold the unsigned->signed recentering (-scale * 2^(3-1)) into zero.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
    // Process 8 values (3 packed bytes) per iteration.
    for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      uchar b0 = B_ptr[3 * (k / 8) + 0]; // high bits of all 8 values
      uchar b1 = B_ptr[3 * (k / 8) + 1]; // low 2 bits of values 0..3
      uchar b2 = B_ptr[3 * (k / 8) + 2]; // low 2 bits of values 4..7

      uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
      uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
      uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
      uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);

      uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
      uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
      uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
      uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}
|
||
// Explicit instantiation of int3pack_mm for a given (DTYPE, GSIZE) pair.
// [[host_name]] assigns the stable, host-visible function name
// "int3pack_mm_<GSIZE>_<DTYPE>" under which the kernel can be looked up.
#define INSTANTIATE_INT3MM(DTYPE, GSIZE)                                 \
template                                                                 \
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]]                          \
kernel void int3pack_mm<DTYPE, GSIZE>(                                   \
    constant DTYPE * A [[buffer(0)]],                                    \
    constant uchar * B [[buffer(1)]],                                    \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData [[buffer(3)]],                             \
    constant uint3 & sizes [[buffer(4)]],                                \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
INSTANTIATE_INT3MM(half, 32);
INSTANTIATE_INT3MM(float, 64);
INSTANTIATE_INT3MM(half, 64);
INSTANTIATE_INT3MM(float, 128);
INSTANTIATE_INT3MM(half, 128);
INSTANTIATE_INT3MM(float, 256);
INSTANTIATE_INT3MM(half, 256);
#if __METAL_VERSION__ >= 310
// bfloat requires Metal 3.1 or newer.
INSTANTIATE_INT3MM(bfloat, 32);
INSTANTIATE_INT3MM(bfloat, 64);
INSTANTIATE_INT3MM(bfloat, 128);
INSTANTIATE_INT3MM(bfloat, 256);
#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
/**
 * 5-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1 (one thread per output element)
 *
 * Packing layout (5 bytes encode 8 consecutive 5-bit values):
 *   b0 bit i (i = 0..7)      -> high bit (bit 4) of value i
 *   b1..b4 low/high nibbles  -> low 4 bits of value pairs (0,1), (2,3), (4,5), (6,7)
 *
 * NOTE(review): the inner loop advances by 8 and only checks `k < K` at loop
 * entry, so K is assumed to be a multiple of 8 (consistent with the
 * N x (5 * K / 8) weight shape); otherwise the last iteration reads past K.
 */
template<typename T, unsigned groupSize>
kernel void int5pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along K (ceiling division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  // Row n of the packed weights: 5*K/8 bytes per row.
  constant uchar *B_ptr = B + n * 5 * K / 8;

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Fold the unsigned->signed recentering (-scale * 2^(5-1)) into zero.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
    // Process 8 values (5 packed bytes) per iteration.
    for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      uchar b0 = B_ptr[5 * (k / 8) + 0]; // high bits of all 8 values
      uchar b1 = B_ptr[5 * (k / 8) + 1]; // low nibbles of values 0, 1
      uchar b2 = B_ptr[5 * (k / 8) + 2]; // low nibbles of values 2, 3
      uchar b3 = B_ptr[5 * (k / 8) + 3]; // low nibbles of values 4, 5
      uchar b4 = B_ptr[5 * (k / 8) + 4]; // low nibbles of values 6, 7

      uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
      uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
      uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
      uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);

      uchar w_val4 = ((b0 & 16)) | (b3 & 15);
      uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
      uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
      uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}
|
||
// Explicit instantiation of int5pack_mm for a given (DTYPE, GSIZE) pair.
// [[host_name]] assigns the stable, host-visible function name
// "int5pack_mm_<GSIZE>_<DTYPE>" under which the kernel can be looked up.
#define INSTANTIATE_INT5MM(DTYPE, GSIZE)                                 \
template                                                                 \
[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]]                          \
kernel void int5pack_mm<DTYPE, GSIZE>(                                   \
    constant DTYPE * A [[buffer(0)]],                                    \
    constant uchar * B [[buffer(1)]],                                    \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData [[buffer(3)]],                             \
    constant uint3 & sizes [[buffer(4)]],                                \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT5MM(float, 32);
INSTANTIATE_INT5MM(half, 32);
INSTANTIATE_INT5MM(float, 64);
INSTANTIATE_INT5MM(half, 64);
INSTANTIATE_INT5MM(float, 128);
INSTANTIATE_INT5MM(half, 128);
INSTANTIATE_INT5MM(float, 256);
INSTANTIATE_INT5MM(half, 256);
#if __METAL_VERSION__ >= 310
// bfloat requires Metal 3.1 or newer.
INSTANTIATE_INT5MM(bfloat, 32);
INSTANTIATE_INT5MM(bfloat, 64);
INSTANTIATE_INT5MM(bfloat, 128);
INSTANTIATE_INT5MM(bfloat, 256);
#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
/**
 * 6-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1 (one thread per output element)
 *
 * Packing layout (3 bytes encode 4 consecutive 6-bit values):
 *   b0 bit pairs          -> high 2 bits (bits 4..5) of values 0..3
 *   b1 low/high nibbles   -> low 4 bits of values 0, 1
 *   b2 low/high nibbles   -> low 4 bits of values 2, 3
 *
 * NOTE(review): the inner loop advances by 4 and only checks `k < K` at loop
 * entry, so K is assumed to be a multiple of 4 (consistent with the
 * N x (6 * K / 8) weight shape); otherwise the last iteration reads past K.
 */
template<typename T, unsigned groupSize>
kernel void int6pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along K (ceiling division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  // Row n of the packed weights: 6*K/8 == 3*K/4 bytes per row.
  constant uchar *B_ptr = B + n * 3 * K / 4;

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Fold the unsigned->signed recentering (-scale * 2^(6-1)) into zero.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32);
    // Process 4 values (3 packed bytes) per iteration.
    for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);

      uchar b0 = B_ptr[3 * (k / 4) + 0]; // high 2 bits of all 4 values
      uchar b1 = B_ptr[3 * (k / 4) + 1]; // low nibbles of values 0, 1
      uchar b2 = B_ptr[3 * (k / 4) + 2]; // low nibbles of values 2, 3

      uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15);
      uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4);
      uchar w_val2 = ((b0 & 48)) | (b2 & 15);
      uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}
|
||
// Explicit instantiation of int6pack_mm for a given (DTYPE, GSIZE) pair.
// [[host_name]] assigns the stable, host-visible function name
// "int6pack_mm_<GSIZE>_<DTYPE>" under which the kernel can be looked up.
#define INSTANTIATE_INT6MM(DTYPE, GSIZE)                                 \
template                                                                 \
[[host_name("int6pack_mm_" #GSIZE "_" #DTYPE)]]                          \
kernel void int6pack_mm<DTYPE, GSIZE>(                                   \
    constant DTYPE * A [[buffer(0)]],                                    \
    constant uchar * B [[buffer(1)]],                                    \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData [[buffer(3)]],                             \
    constant uint3 & sizes [[buffer(4)]],                                \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT6MM(float, 32);
INSTANTIATE_INT6MM(half, 32);
INSTANTIATE_INT6MM(float, 64);
INSTANTIATE_INT6MM(half, 64);
INSTANTIATE_INT6MM(float, 128);
INSTANTIATE_INT6MM(half, 128);
INSTANTIATE_INT6MM(float, 256);
INSTANTIATE_INT6MM(half, 256);
#if __METAL_VERSION__ >= 310
// bfloat requires Metal 3.1 or newer.
INSTANTIATE_INT6MM(bfloat, 32);
INSTANTIATE_INT6MM(bfloat, 64);
INSTANTIATE_INT6MM(bfloat, 128);
INSTANTIATE_INT6MM(bfloat, 256);
#endif
Oops, something went wrong.