From e783565372c39780a5c6198c14c9dd59b0481b07 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Sep 2024 12:55:45 -0700
Subject: [PATCH] Introduce lowbit quantized linear MPS kernels (#954)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary: The following is the directory structure of the submitted code under torchao

```
experimental/
├── kernels/
│   └── mps/
│       ├── metal/
│       │   └── (metal shaders)
│       ├── src/
│       │   └── (tensor agnostic mps kernel implementations)
│       └── test/
│           └── (directly test mps kernel implementations)
└── ops/
    └── mps/
        ├── register.mm
        ├── setup.py
        └── test/
            └── (test torch custom ops)
```

Differential Revision: D63342895
---
 .../kernels/mps/metal/divbit.metal            |  97 ++++++++
 .../kernels/mps/metal/int1mm.metal            |  63 +++++
 .../kernels/mps/metal/int2mm.metal            |  63 +++++
 .../kernels/mps/metal/int3mm.metal            |  88 +++++++
 .../kernels/mps/metal/int3mm_vec.metal        |  97 ++++++++
 .../kernels/mps/metal/int4mm.metal            |  80 +++++++
 .../kernels/mps/metal/int4mv_opt.metal        | 213 +++++++++++++++++
 .../kernels/mps/metal/int5mm.metal            |  90 +++++++
 .../kernels/mps/metal/int6mm.metal            |  75 ++++++
 .../kernels/mps/metal/int7mm.metal            |  92 ++++++++
 .../kernels/mps/src/OperationUtils.h          | 142 +++++++++++
 .../experimental/kernels/mps/src/dispatch.h   |  54 +++++
 torchao/experimental/kernels/mps/src/lowbit.h | 223 ++++++++++++++++++
 .../experimental/kernels/mps/src/packing.h    | 192 +++++++++++++++
 .../experimental/kernels/mps/test/Makefile    |   7 +
 .../experimental/kernels/mps/test/bfloat16.h  |  61 +++++
 .../kernels/mps/test/test_lowbit.mm           | 181 ++++++++++++++
 torchao/experimental/ops/mps/register.mm      | 114 +++++++++
 torchao/experimental/ops/mps/setup.py         |  23 ++
 .../experimental/ops/mps/test/test_lowbit.py  |  82 +++++++
 20 files changed, 2037 insertions(+)
 create mode 100644 torchao/experimental/kernels/mps/metal/divbit.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int1mm.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int2mm.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int3mm.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int3mm_vec.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int4mm.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int4mv_opt.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int5mm.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int6mm.metal
 create mode 100644 torchao/experimental/kernels/mps/metal/int7mm.metal
 create mode 100644 torchao/experimental/kernels/mps/src/OperationUtils.h
 create mode 100644 torchao/experimental/kernels/mps/src/dispatch.h
 create mode 100644 torchao/experimental/kernels/mps/src/lowbit.h
 create mode 100644 torchao/experimental/kernels/mps/src/packing.h
 create mode 100644 torchao/experimental/kernels/mps/test/Makefile
 create mode 100644 torchao/experimental/kernels/mps/test/bfloat16.h
 create mode 100644 torchao/experimental/kernels/mps/test/test_lowbit.mm
 create mode 100644 torchao/experimental/ops/mps/register.mm
 create mode 100644 torchao/experimental/ops/mps/setup.py
 create mode 100644 torchao/experimental/ops/mps/test/test_lowbit.py

diff --git a/torchao/experimental/kernels/mps/metal/divbit.metal b/torchao/experimental/kernels/mps/metal/divbit.metal
new file mode 100644
index 0000000000..d3b311853b
--- /dev/null
+++ b/torchao/experimental/kernels/mps/metal/divbit.metal
@@ -0,0 +1,97 @@
+#include <metal_stdlib>
+using namespace metal;
+
+// dispatchThreads:MTLSizeMake(N, M, 1)
+
+template <typename T, unsigned groupSize, unsigned nbit>
+kernel void 
divbit_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B; + + constexpr uint8_t zero_shift = 1 << (nbit - 1); + constexpr uint8_t inv_nbit = 8 / nbit; + constexpr uint8_t minimask = (1 << nbit) - 1; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const T scale = scalesAndZeros[(kb * N + n) * 2 + 0]; + const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift); + for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uint8_t b_val = B_ptr[(n * K + k) / inv_nbit]; + uint8_t shift = nbit * (k % inv_nbit); + uint8_t mask = minimask << shift; + b_val = (b_val & mask) >> shift; + rc += a_val * float(scale * T(b_val) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \ +template \ +[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void divbit_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_DIVBIT_MM(1, float, 32); +INSTANTIATE_DIVBIT_MM(1, half, 32); +INSTANTIATE_DIVBIT_MM(1, float, 64); +INSTANTIATE_DIVBIT_MM(1, half, 64); +INSTANTIATE_DIVBIT_MM(1, float, 128); +INSTANTIATE_DIVBIT_MM(1, half, 128); +INSTANTIATE_DIVBIT_MM(1, float, 256); +INSTANTIATE_DIVBIT_MM(1, half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_DIVBIT_MM(1, bfloat, 32); +INSTANTIATE_DIVBIT_MM(1, bfloat, 64); +INSTANTIATE_DIVBIT_MM(1, bfloat, 128); +INSTANTIATE_DIVBIT_MM(1, bfloat, 256); +#endif + +INSTANTIATE_DIVBIT_MM(2, float, 32); +INSTANTIATE_DIVBIT_MM(2, half, 32); +INSTANTIATE_DIVBIT_MM(2, float, 64); +INSTANTIATE_DIVBIT_MM(2, half, 64); +INSTANTIATE_DIVBIT_MM(2, float, 128); +INSTANTIATE_DIVBIT_MM(2, half, 128); +INSTANTIATE_DIVBIT_MM(2, float, 256); +INSTANTIATE_DIVBIT_MM(2, half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_DIVBIT_MM(2, bfloat, 32); +INSTANTIATE_DIVBIT_MM(2, bfloat, 64); +INSTANTIATE_DIVBIT_MM(2, bfloat, 128); +INSTANTIATE_DIVBIT_MM(2, bfloat, 256); +#endif + +INSTANTIATE_DIVBIT_MM(4, float, 32); +INSTANTIATE_DIVBIT_MM(4, half, 32); +INSTANTIATE_DIVBIT_MM(4, float, 64); +INSTANTIATE_DIVBIT_MM(4, half, 64); +INSTANTIATE_DIVBIT_MM(4, float, 128); +INSTANTIATE_DIVBIT_MM(4, half, 128); +INSTANTIATE_DIVBIT_MM(4, float, 256); +INSTANTIATE_DIVBIT_MM(4, half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_DIVBIT_MM(4, bfloat, 32); +INSTANTIATE_DIVBIT_MM(4, bfloat, 64); +INSTANTIATE_DIVBIT_MM(4, bfloat, 128); +INSTANTIATE_DIVBIT_MM(4, bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int1mm.metal b/torchao/experimental/kernels/mps/metal/int1mm.metal new file mode 100644 index 0000000000..f35aeba7ba --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int1mm.metal @@ -0,0 +1,63 @@ +#include +using namespace metal; + +// dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int1pack_mm( + constant T * A 
[[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale; + for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uint8_t b_val = B_ptr[(n * K + k) / 8]; + uint8_t shift = (k % 8); + uint8_t mask = 1 << shift; + b_val = (b_val & mask) >> shift; + rc += a_val * (scale * float(b_val) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \ +template \ +[[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int1pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT1MM(float, 32); +INSTANTIATE_INT1MM(half, 32); +INSTANTIATE_INT1MM(float, 64); +INSTANTIATE_INT1MM(half, 64); +INSTANTIATE_INT1MM(float, 128); +INSTANTIATE_INT1MM(half, 128); +INSTANTIATE_INT1MM(float, 256); +INSTANTIATE_INT1MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT1MM(bfloat, 32); +INSTANTIATE_INT1MM(bfloat, 64); +INSTANTIATE_INT1MM(bfloat, 128); +INSTANTIATE_INT1MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int2mm.metal b/torchao/experimental/kernels/mps/metal/int2mm.metal new file mode 100644 index 0000000000..73ce1b2566 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int2mm.metal @@ -0,0 +1,63 @@ +#include +using namespace metal; + +// dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int2pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = scalesAndZeros[(kb * N + n) * 2 + 0]; + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(2); + for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uint8_t b_val = B_ptr[(n * K + k) / 4]; + uint8_t shift = 2 * (k % 4); + uint8_t mask = 3 << shift; + b_val = (b_val & mask) >> shift; + rc += a_val * (scale * float(b_val) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT2MM(DTYPE, GSIZE) \ +template \ +[[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int2pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * 
scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT2MM(float, 32); +INSTANTIATE_INT2MM(half, 32); +INSTANTIATE_INT2MM(float, 64); +INSTANTIATE_INT2MM(half, 64); +INSTANTIATE_INT2MM(float, 128); +INSTANTIATE_INT2MM(half, 128); +INSTANTIATE_INT2MM(float, 256); +INSTANTIATE_INT2MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT2MM(bfloat, 32); +INSTANTIATE_INT2MM(bfloat, 64); +INSTANTIATE_INT2MM(bfloat, 128); +INSTANTIATE_INT2MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int3mm.metal b/torchao/experimental/kernels/mps/metal/int3mm.metal new file mode 100644 index 0000000000..a58e9cb38f --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int3mm.metal @@ -0,0 +1,88 @@ +#include +using namespace metal; + +// dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int3pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 3 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + + uchar b0 = B_ptr[3 * (k / 8) + 0]; + uchar b1 = B_ptr[3 * (k / 8) + 1]; + uchar b2 = B_ptr[3 * (k / 8) + 2]; + + uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3); + uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2); + uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4); + uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6); + + uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3); + uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2); + uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4); + uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \ +template \ +[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int3pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT3MM(float, 32); 
+INSTANTIATE_INT3MM(half, 32); +INSTANTIATE_INT3MM(float, 64); +INSTANTIATE_INT3MM(half, 64); +INSTANTIATE_INT3MM(float, 128); +INSTANTIATE_INT3MM(half, 128); +INSTANTIATE_INT3MM(float, 256); +INSTANTIATE_INT3MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT3MM(bfloat, 32); +INSTANTIATE_INT3MM(bfloat, 64); +INSTANTIATE_INT3MM(bfloat, 128); +INSTANTIATE_INT3MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int3mm_vec.metal b/torchao/experimental/kernels/mps/metal/int3mm_vec.metal new file mode 100644 index 0000000000..d2d2d69e42 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int3mm_vec.metal @@ -0,0 +1,97 @@ +#include +using namespace metal; + +template struct Vec4Type {}; + +template <> struct Vec4Type { + using type = float4; +}; + +template <> struct Vec4Type { + using type = half4; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { + using type = bfloat4; +}; +#endif + +// dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int3pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + + using vecT = typename Vec4Type::type; + constant vecT *A_ptr = reinterpret_cast(A + m * K); + constant uchar *B_ptr = B + n * 3 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_vals0 = float4(A_ptr[k/4]); + const auto a_vals1 = float4(A_ptr[k/4 + 1]); + + uchar b2 = B_ptr[3 * (k / 8) + 0]; + uchar b10_0 = B_ptr[3 * (k / 8) + 1]; + uchar b10_1 = B_ptr[3 * (k / 8) + 2]; + + uchar w_val0 = ((b2 & 1) << 2) | (b10_0 & 3); + uchar w_val1 = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); + uchar w_val2 = (b2 & 4) | ((b10_0 & 48) >> 4); + uchar w_val3 = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); + + uchar w_val4 = ((b2 & 16) >> 2) | (b10_1 & 3); + uchar w_val5 = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); + uchar w_val6 = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); + uchar w_val7 = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); + + uchar4 w_vals0 = uchar4(w_val0, w_val1, w_val2, w_val3); + uchar4 w_vals1 = uchar4(w_val4, w_val5, w_val6, w_val7); + + rc += dot(a_vals0, scale * float4(w_vals0) + zero); + rc += dot(a_vals1, scale * float4(w_vals1) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \ +template \ +[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int3pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT3MM(float, 32); +INSTANTIATE_INT3MM(half, 32); +INSTANTIATE_INT3MM(float, 64); +INSTANTIATE_INT3MM(half, 64); +INSTANTIATE_INT3MM(float, 128); +INSTANTIATE_INT3MM(half, 128); +INSTANTIATE_INT3MM(float, 256); +INSTANTIATE_INT3MM(half, 256); +#if __METAL_VERSION__ >= 310 
+INSTANTIATE_INT3MM(bfloat, 32);
+INSTANTIATE_INT3MM(bfloat, 64);
+INSTANTIATE_INT3MM(bfloat, 128);
+INSTANTIATE_INT3MM(bfloat, 256);
+#endif
diff --git a/torchao/experimental/kernels/mps/metal/int4mm.metal b/torchao/experimental/kernels/mps/metal/int4mm.metal
new file mode 100644
index 0000000000..5c2e834883
--- /dev/null
+++ b/torchao/experimental/kernels/mps/metal/int4mm.metal
@@ -0,0 +1,80 @@
+#include <metal_stdlib>
+using namespace metal;
+
+// dispatchThreads:MTLSizeMake(N, M, 1)
+
+kernel void weight_to_int4pack(constant int *W [[buffer(0)]],
+                               device uchar *outputData [[buffer(1)]],
+                               constant uint2 &sizes [[buffer(2)]],
+                               uint2 thread_index [[thread_position_in_grid]]) {
+  const uint N = sizes.x;
+  const uint K = sizes.y;
+  const uint n = thread_index.x; // 0..N-1
+  const uint k2 = thread_index.y; // 0..K/2-1
+  int32_t src_val0 = W[n * K + 2 * k2];
+  int32_t src_val1 = W[n * K + 2 * k2 + 1];
+  outputData[n * (K / 2) + k2] = (uint8_t(src_val1) << 4) | uint8_t(src_val0);
+}
+
+template <typename T, unsigned groupSize>
+kernel void int4pack_mm(
+    constant T * A [[buffer(0)]],
+    constant uchar * B [[buffer(1)]],
+    constant T * scalesAndZeros [[buffer(2)]],
+    device T * outputData [[buffer(3)]],
+    constant uint3 & sizes [[buffer(4)]], // M, K, N
+    uint2 thread_index [[thread_position_in_grid]]) {
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  const uint m = thread_index.y; // 0..M-1
+  const uint n = thread_index.x; // 0..N-1
+  const uint32_t k_block = (K + groupSize - 1) / groupSize;
+  constant T *A_ptr = A + m * K;
+  constant uchar *B_ptr = B;
+
+  float rc = 0.0;
+  uint k = 0;
+  for (uint32_t kb = 0; kb < k_block ; kb ++) {
+    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
+    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(8);
+    for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
+      const auto a_val = float(A_ptr[k]);
+      uchar b_val = B_ptr[(n * K + k) / 2];
+
+      // b_val = (k & 1) == 0 ? 
b_val & 0x0f : (b_val >> 4); + + uint8_t shift = 4 * (k % 2); + uint8_t mask = 15 << shift; + b_val = (b_val & mask) >> shift; + + rc += a_val * (scale * float(b_val) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ +template \ +[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int4mv_opt.metal b/torchao/experimental/kernels/mps/metal/int4mv_opt.metal new file mode 100644 index 0000000000..b54c3e7d44 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int4mv_opt.metal @@ -0,0 +1,213 @@ +#include +#include +using namespace metal; + +kernel void weight_to_int4pack(constant int *W [[buffer(0)]], + device uchar *outputData [[buffer(1)]], + constant uint2 &sizes [[buffer(2)]], + uint2 thread_index [[thread_position_in_grid]]) { + const uint N = sizes.x; + const uint K = sizes.y; + const uint n = thread_index.x; // 0..N-1 + const uint k2 = thread_index.y; // 0..K/2-1 + int32_t src_val0 = W[n * K + 2 * k2]; + int32_t src_val1 = W[n * K + 2 * k2 + 1]; + outputData[n * (K / 2) + k2] = (uint8_t(src_val1) << 4) | uint8_t(src_val0); +} + +template struct Vec4Type {}; + +template <> struct Vec4Type { using type = float4; }; + +template <> struct Vec4Type { using type = half4; }; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { using type = bfloat4; }; +#endif + +/* + This code takes heavy inspiration from MLX qvm kernel here: + https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.metal#L381 + Specifically: + - Multiplying activation by inverse scaling factor to reduce compute + boundedness + - Handling zero point by accumulating act in separate sum term. Needed with + optimization done above. MLX MIT License: + https://github.com/ml-explore/mlx/blob/main/LICENSE +*/ + +/* + A matrix is [M x K] (right now this kernel does not support M > 1 but this is + a very easy fix that will follow right after) B matrix is [N x K]. For 4 bit + 2 of the k values are packed in one byte so you can think of B as [N x K/2] + matrix from layout perspective. + + Since this kernel is optimizing for gemv case, we split work, along reduction + dim k, among the threads of same simdgroup. Ex: if K = 4096 and simdgroup + size is 32 (current algorithm should work as long as simdgroup size is > 32). + Then each thread will accumulate 4096/32 = 128 k values. However these 128 + values, handled by each thread are not laid out contiguously. Each thread + handles 4 contiguous k values and then jumps 128 elements, k_jump = + thread_per_channel (32) * ks_per_thread (4). Take a simpler example where + simdgroup is of size 4. In this case threads_per_channel = 4. 
Assume K = 32
+      k                thread
+  [0, 1, 2, 3,           0
+   4, 5, 6, 7,           1
+   8, 9, 10, 11,         2
+   12, 13, 14, 15,       3
+   16, 17, 18, 19,       0
+   20, 21, 22, 23,       1
+   24, 25, 26, 27,       2
+   28, 29, 30, 31]       3
+  (right column: thread id in simdgroup that handles the corresponding ks)
+  Thread 0 here is handling (0, 1, 2, 3) and then (16, 17, 18, 19). They are
+  apart by k_jump = 4 * 4 = 16. This is done to improve memory access locality
+  among threads that are working co-operatively. Once each thread has its
+  partial sums accumulated, we use tree reduction (Metal offers simd_sum, but
+  it is not used so that we support simdgroup size = 64). In the example
+  above we will have 4 partial sums.
+
+  Each thread also handles 4 different output rows. Thus each simdgroup will
+  be responsible for a (1x4) tile of the output. We haven't evaluated whether
+  a different tile size is better or not. We probably will do some auto-tuning
+  once initial work is done.
+
+*/
+
+/*
+   @brief This shader implements 4-bit matrix-vector multiplication where A
+   matrix is fp16, bfloat or float and B matrix is a 4-bit groupwise-quantized
+   weight matrix.
+   @param [in] A is activation matrix of size M x K.
+   @param [in] B is weight matrix of size N x K. Each byte contains 2 4-bit
+   values, packed along the K dim.
+   @param [in] scales_and_zeros is scales and zero points corresponding to
+   each output channel x group. These are packed as
+   [num_groups = ceil(K / group_size), N, 2]. N = output channels.
+   @param [out] output_data is output matrix of size M x N.
+   @param [in] sizes array contains values of M, N and K.
+   @param [in] thread_index is global thread id.
+   @param [in] tid_in_simdgroup is thread id in simdgroup. e.g. in a simdgroup
+   of size 32 it can be in [0-31].
+*/
+/*
+TODOs:
+   Right now code handles only M = 1 case. Fix that.
+*/
+template <typename T, unsigned group_size>
+kernel void int4pack_mm(constant T *A [[buffer(0)]],
+                        constant uchar *B [[buffer(1)]],
+                        constant T *scales_and_zeros [[buffer(2)]],
+                        device T *output_data [[buffer(3)]],
+                        constant uint3 &sizes [[buffer(4)]], // M, K, N
+                        uint3 thread_index [[thread_position_in_grid]],
+                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+  constexpr uint threads_per_channel = 32;
+  constexpr uint ks_per_thread = 4;
+  constexpr uint k_pack_factor = 2;
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  uint n = thread_index.x; // 0..N/4-1
+  uint m = thread_index.z; // 0..M
+  n = n / threads_per_channel;
+  n = n * 4;
+  // This is starting k for each thread. In the example above, for thread 1
+  // this value will be 4.
+  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+  constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+  using vecT = typename Vec4Type<T>::type;
+  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+  thread float4 result = float4(0.0);
+  // We multiply a group of 4 channels by these scales, because the
+  // corresponding values from the weight matrix are effectively left
+  // shifted. This avoids doing a right shift on those values, which ends up
+  // affecting performance. This is the trick applied in MLX kernels.
+  float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f};
+
+  for (; k < K; k += k_jump) {
+    // Find the specific group to which the channels handled by this thread
+    // belong.
+    uint k_block_index = k / group_size;
+    // Since scales_and_zeros are packed as [num_groups, N, 2].
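+    // (Worked example, added for exposition, not in the original comment:
+    // with N = 4096, the scale for group k_block_index and output channel n
+    // sits at scales_and_zeros[(k_block_index * 4096 + n) * 2] and its zero
+    // point right after it, at the same index + 1.)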
+ // Finding a specific's group's scales and zero points requires jump by factor + // of N*2 + uint scales_group_offset = (k_block_index * N + n) * 2; + uint zeros_gruop_offset = scales_group_offset + 1; + + const T scale0 = scales_and_zeros[scales_group_offset]; + // Adding zero point results in 10% perf penalty. + const T zero0 = scales_and_zeros[zeros_gruop_offset] - scale0 * T(8); + + const T scale1 = scales_and_zeros[scales_group_offset + 2]; + const T zero1 = scales_and_zeros[zeros_gruop_offset + 2] - scale1 * T(8); + + const T scale2 = scales_and_zeros[scales_group_offset + 4]; + const T zero2 = scales_and_zeros[zeros_gruop_offset + 4] - scale2 * T(8); + + const T scale3 = scales_and_zeros[scales_group_offset + 6]; + const T zero3 = scales_and_zeros[zeros_gruop_offset + 6] - scale3 * T(8); + + const float4 zeros = float4(zero0, zero1, zero2, zero3); + + float4 a_val = float4(A_ptr[k / 4]); + // We are gonna skip right-shifts of the weights and hence divide by corresponding factor. + float4 a_vec = a_val * act_div_scales; + float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3]; + + float4x4 b_mat; + ushort b_val0 = (reinterpret_cast( + B_ptr + (k + 0 * K) / k_pack_factor))[0]; + ushort b_val1 = (reinterpret_cast( + B_ptr + (k + 1 * K) / k_pack_factor))[0]; + ushort b_val2 = (reinterpret_cast( + B_ptr + (k + 2 * K) / k_pack_factor))[0]; + ushort b_val3 = (reinterpret_cast( + B_ptr + (k + 3 * K) / k_pack_factor))[0]; + b_mat[0] = scale0 * float4(float(b_val0 & 0x000f), float(b_val0 & 0x00f0), + float(b_val0 & 0x0f00), float(b_val0 & 0xf000)); + b_mat[1] = scale1 * float4(float(b_val1 & 0x000f), float(b_val1 & 0x00f0), + float(b_val1 & 0x0f00), float(b_val1 & 0xf000)); + b_mat[2] = scale2 * float4(float(b_val2 & 0x000f), float(b_val2 & 0x00f0), + float(b_val2 & 0x0f00), float(b_val2 & 0xf000)); + b_mat[3] = scale3 * float4(float(b_val3 & 0x000f), float(b_val3 & 0x00f0), + float(b_val3 & 0x0f00), float(b_val3 & 0xf000)); + + result += a_vec * b_mat; + result += a_val_sum * zeros; + } + result += simd_shuffle_down(result, 1); + result += simd_shuffle_down(result, 2); + result += simd_shuffle_down(result, 4); + result += simd_shuffle_down(result, 8); + result += simd_shuffle_down(result, 16); + if (tid_in_simdgroup % threads_per_channel == 0) { + reinterpret_cast(output_data + m * N)[n / 4] = vecT(result); + } +} + +#define INSTANTIATE_INT4MV(DTYPE, GSIZE) \ + template [[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \ + int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_and_zeros [[buffer(2)]], \ + device DTYPE * output_data [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + +INSTANTIATE_INT4MV(float, 32); +INSTANTIATE_INT4MV(half, 32); +INSTANTIATE_INT4MV(float, 64); +INSTANTIATE_INT4MV(half, 64); +INSTANTIATE_INT4MV(float, 128); +INSTANTIATE_INT4MV(half, 128); +INSTANTIATE_INT4MV(float, 256); +INSTANTIATE_INT4MV(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MV(bfloat, 32); +INSTANTIATE_INT4MV(bfloat, 64); +INSTANTIATE_INT4MV(bfloat, 128); +INSTANTIATE_INT4MV(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int5mm.metal b/torchao/experimental/kernels/mps/metal/int5mm.metal new file mode 100644 index 0000000000..f3f9921e54 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int5mm.metal @@ -0,0 +1,90 @@ +#include +using namespace metal; + +// 
dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int5pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 5 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + + uchar b0 = B_ptr[5 * (k / 8) + 0]; + uchar b1 = B_ptr[5 * (k / 8) + 1]; + uchar b2 = B_ptr[5 * (k / 8) + 2]; + uchar b3 = B_ptr[5 * (k / 8) + 3]; + uchar b4 = B_ptr[5 * (k / 8) + 4]; + + uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15); + uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4); + uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15); + uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4); + + uchar w_val4 = ((b0 & 16)) | (b3 & 15); + uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4); + uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15); + uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT5MM(DTYPE, GSIZE) \ +template \ +[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int5pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT5MM(float, 32); +INSTANTIATE_INT5MM(half, 32); +INSTANTIATE_INT5MM(float, 64); +INSTANTIATE_INT5MM(half, 64); +INSTANTIATE_INT5MM(float, 128); +INSTANTIATE_INT5MM(half, 128); +INSTANTIATE_INT5MM(float, 256); +INSTANTIATE_INT5MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT5MM(bfloat, 32); +INSTANTIATE_INT5MM(bfloat, 64); +INSTANTIATE_INT5MM(bfloat, 128); +INSTANTIATE_INT5MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int6mm.metal b/torchao/experimental/kernels/mps/metal/int6mm.metal new file mode 100644 index 0000000000..a2b252b796 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int6mm.metal @@ -0,0 +1,75 @@ +#include +using namespace metal; + +// dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int6pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + 
constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 3 * K / 4; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32); + for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + + uchar b0 = B_ptr[3 * (k / 4) + 0]; + uchar b1 = B_ptr[3 * (k / 4) + 1]; + uchar b2 = B_ptr[3 * (k / 4) + 2]; + + uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15); + uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4); + uchar w_val2 = ((b0 & 48)) | (b2 & 15); + uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT6MM(DTYPE, GSIZE) \ +template \ +[[host_name("int6pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int6pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT6MM(float, 32); +INSTANTIATE_INT6MM(half, 32); +INSTANTIATE_INT6MM(float, 64); +INSTANTIATE_INT6MM(half, 64); +INSTANTIATE_INT6MM(float, 128); +INSTANTIATE_INT6MM(half, 128); +INSTANTIATE_INT6MM(float, 256); +INSTANTIATE_INT6MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT6MM(bfloat, 32); +INSTANTIATE_INT6MM(bfloat, 64); +INSTANTIATE_INT6MM(bfloat, 128); +INSTANTIATE_INT6MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int7mm.metal b/torchao/experimental/kernels/mps/metal/int7mm.metal new file mode 100644 index 0000000000..cdb18f2026 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int7mm.metal @@ -0,0 +1,92 @@ +#include +using namespace metal; + +// dispatchThreads:MTLSizeMake(N, M, 1) + +template +kernel void int7pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 7 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(64); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_val0 = float(A_ptr[k 
+ 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + + uchar b0 = B_ptr[7 * (k / 8) + 0]; + uchar b1 = B_ptr[7 * (k / 8) + 1]; + uchar b2 = B_ptr[7 * (k / 8) + 2]; + uchar b3 = B_ptr[7 * (k / 8) + 3]; + uchar b4 = B_ptr[7 * (k / 8) + 4]; + uchar b5 = B_ptr[7 * (k / 8) + 5]; + uchar b6 = B_ptr[7 * (k / 8) + 6]; + + uchar w_val0 = b0 & 127; + uchar w_val1 = b1 & 127; + uchar w_val2 = b2 & 127; + uchar w_val3 = b3 & 127; + uchar w_val4 = b4 & 127; + uchar w_val5 = b5 & 127; + uchar w_val6 = b6 & 127; + uchar w_val7 = ((b0 & 128) >> 7) | ((b1 & 128) >> 6) | ((b2 & 128) >> 5) | ((b3 & 128) >> 4) + | ((b4 & 128) >> 3) | ((b5 & 128) >> 2) | ((b6 & 128) >> 1); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT7MM(DTYPE, GSIZE) \ +template \ +[[host_name("int7pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int7pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT7MM(float, 32); +INSTANTIATE_INT7MM(half, 32); +INSTANTIATE_INT7MM(float, 64); +INSTANTIATE_INT7MM(half, 64); +INSTANTIATE_INT7MM(float, 128); +INSTANTIATE_INT7MM(half, 128); +INSTANTIATE_INT7MM(float, 256); +INSTANTIATE_INT7MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT7MM(bfloat, 32); +INSTANTIATE_INT7MM(bfloat, 64); +INSTANTIATE_INT7MM(bfloat, 128); +INSTANTIATE_INT7MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.h b/torchao/experimental/kernels/mps/src/OperationUtils.h new file mode 100644 index 0000000000..80857b1302 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/OperationUtils.h @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
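+
+// Usage sketch (illustrative only; it mirrors how lowbit.h drives these
+// helpers and is not part of the original header): compile a shader source
+// string once, fetch a pipeline state by host function name, encode, flush.
+//
+//   MetalShaderLibrary lib(shader_source);   // compiles the .metal source
+//   id<MTLComputePipelineState> cpl =
+//       lib.getPipelineStateForFunc("int4pack_mm_32_float");
+//   MPSStream* stream = getCurrentMPSStream();
+//   id<MTLComputeCommandEncoder> enc = stream->commandEncoder();
+//   [enc setComputePipelineState:cpl];
+//   // ... setBuffer:/setBytes: for A, B, scalesAndZeros, output, sizes ...
+//   finalize_block(stream);  // endEncoding, commit, waitUntilCompleted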
+ +#pragma once + +#include + +static void fail(const std::string& str) { + std::cerr << str << std::endl; + abort(); +} + +static void fail(const std::string& str1, const std::string& str2) { + std::cerr << str1 << str2 << std::endl; + abort(); +} + +inline void dispatch_sync_with_rethrow( + id queue, + void (^block)()) { + (void)queue; + block(); +} + +inline id getMetalDevice() { + NSArray* devices = [MTLCopyAllDevices() autorelease]; + if (devices.count == 0) { + fail("Metal is not supported"); + } + return devices[0]; +} + +static id MTL_DEVICE = getMetalDevice(); + +static id compileLibraryFromSource( + id device, + const std::string& source) { + NSError* error = nil; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:MTLLanguageVersion3_1]; + NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; + id library = [device newLibraryWithSource:kernel_source + options:options + error:&error]; + if (library == nil) { + fail("Failed to compile: ", error.description.UTF8String); + } + return library; +} + +class MetalShaderLibrary { + public: + MetalShaderLibrary(const std::string& src) : shaderSource(src) { + lib = compileLibraryFromSource(device, shaderSource); + } + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + MetalShaderLibrary(MetalShaderLibrary&&) = delete; + + id getPipelineStateForFunc( + const std::string& fname) { + return get_compute_pipeline_state(load_func(fname)); + } + + private: + std::string shaderSource; + id device = MTL_DEVICE; + id lib = nil; + + id load_func(const std::string& func_name) const { + id func = [lib + newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; + if (func == nil) { + fail("Can't get function:" + func_name); + } + return func; + } + + id get_compute_pipeline_state( + id func) const { + NSError* error = nil; + auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; + if (cpl == nil) { + fail( + "Failed to construct pipeline state: ", error.description.UTF8String); + } + return cpl; + } +}; + +class MPSStream { + public: + MPSStream() { + _commandQueue = [MTL_DEVICE newCommandQueue]; + } + + ~MPSStream() { + [_commandQueue release]; + _commandQueue = nil; + + assert(_commandBuffer == nil); + } + + id queue() const { + return _commandQueue; + } + + id commandBuffer() { + if (!_commandBuffer) { + auto desc = [MTLCommandBufferDescriptor new]; + desc.errorOptions = MTLCommandBufferErrorOptionEncoderExecutionStatus; + _commandBuffer = [_commandQueue commandBufferWithDescriptor:desc]; + } + return _commandBuffer; + } + + id commandEncoder() { + if (!_commandEncoder) { + _commandEncoder = [commandBuffer() computeCommandEncoder]; + } + return _commandEncoder; + } + + private: + id _commandQueue = nil; + id _commandBuffer = nil; + id _commandEncoder = nil; +}; + +inline void finalize_block(MPSStream* mpsStream) { + id encoder = mpsStream->commandEncoder(); + id cmdBuffer = mpsStream->commandBuffer(); + [encoder endEncoding]; + [cmdBuffer commit]; + [cmdBuffer waitUntilCompleted]; +} + +inline MPSStream* getCurrentMPSStream() { + return new MPSStream(); +} diff --git a/torchao/experimental/kernels/mps/src/dispatch.h b/torchao/experimental/kernels/mps/src/dispatch.h new file mode 100644 index 0000000000..54c07e4c64 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/dispatch.h @@ -0,0 +1,54 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace dispatch { + +inline void dispatch_mm( + id encoder, + int32_t maxThreadsPerGroup, + int32_t M, + int32_t N, + int32_t K) { + (void)K; + [encoder dispatchThreads:MTLSizeMake(N, M, 1) + threadsPerThreadgroup:MTLSizeMake( + std::min( + static_cast( + maxThreadsPerGroup), + M), + 1, + 1)]; +} + +inline void dispatch_int4mv_opt( + id encoder, + int32_t maxThreadsPerGroup, + int32_t M, + int32_t N, + int32_t K) { + (void)K; + (void)maxThreadsPerGroup; + // constexpr auto blockSize = 8; + // if (maxThreadsPerGroup < blockSize * blockSize) { + // throw std::runtime_error("Can't dispatch!"); + // } + [encoder dispatchThreads:MTLSizeMake(N / 4 * 32, 1, M) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; +} + +} // namespace dispatch +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h new file mode 100644 index 0000000000..55c169ffe0 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -0,0 +1,223 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +#include +#include + +#include +#include + +#ifdef ATEN +#include +using namespace at::native::mps; +inline void finalize_block(MPSStream* mpsStream) {} +#else +#include +#endif + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace { + +static constexpr std::string_view METAL_SHADER_DIR = + "/Users/mcandales/fbsource/xplat/pytorch/ao/torchao/experimental/kernels/mps/metal"; + +template +struct LowBitConfig {}; + +template <> +struct LowBitConfig<1> { + static constexpr std::string_view metal_filename = "int1mm.metal"; + static constexpr std::string_view func_prefix = "int1pack_mm_"; + static constexpr auto packing_fn = packing::pack<1>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<2> { + static constexpr std::string_view metal_filename = "int2mm.metal"; + static constexpr std::string_view func_prefix = "int2pack_mm_"; + static constexpr auto packing_fn = packing::pack<2>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<3> { + static constexpr std::string_view metal_filename = "int3mm.metal"; + static constexpr std::string_view func_prefix = "int3pack_mm_"; + static constexpr auto packing_fn = packing::pack<3>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<4> { + static constexpr std::string_view metal_filename = "int4mv_opt.metal"; + static constexpr std::string_view func_prefix = "int4pack_mm_"; + static constexpr auto packing_fn = packing::pack<4>; + static constexpr auto dispatch_fn = dispatch::dispatch_int4mv_opt; +}; + +template <> +struct LowBitConfig<5> { + static constexpr std::string_view metal_filename = "int5mm.metal"; + static constexpr std::string_view func_prefix = "int5pack_mm_"; + static constexpr auto packing_fn = packing::pack<5>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<6> { + 
static constexpr std::string_view metal_filename = "int6mm.metal"; + static constexpr std::string_view func_prefix = "int6pack_mm_"; + static constexpr auto packing_fn = packing::pack<6>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<7> { + static constexpr std::string_view metal_filename = "int7mm.metal"; + static constexpr std::string_view func_prefix = "int7pack_mm_"; + static constexpr auto packing_fn = packing::pack<7>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +inline MetalShaderLibrary compileLibraryFromFile(const std::string& fname) { + std::ifstream ifs(fname); + std::stringstream ss; + ss << ifs.rdbuf(); + return MetalShaderLibrary(ss.str()); +} + +inline void linear_lowbit_quant_weights_mps_impl( + id a_buf, + id b_buf, + id sz_buf, + id out_buf, + int32_t M, + int32_t K, + int32_t N, + const std::string shader_path, + const std::string shader_func, + void (*dispatch_fn)( + id, + int32_t, + int32_t, + int32_t, + int32_t)) { + std::array sizes = { + static_cast(M), + static_cast(K), + static_cast(N), + 0}; + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MetalShaderLibrary lib = compileLibraryFromFile(shader_path); + id cpl = + lib.getPipelineStateForFunc(shader_func); + const auto maxThreadsPerGroup = [cpl maxTotalThreadsPerThreadgroup]; + [computeEncoder setComputePipelineState:cpl]; + [computeEncoder setBuffer:a_buf offset:0 atIndex:0]; + [computeEncoder setBuffer:b_buf offset:0 atIndex:1]; + [computeEncoder setBuffer:sz_buf offset:0 atIndex:2]; + [computeEncoder setBuffer:out_buf offset:0 atIndex:3]; + [computeEncoder setBytes:sizes.data() + length:sizeof(uint32_t) * sizes.size() + atIndex:4]; + dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K); + finalize_block(mpsStream); + } + }); +} + +// LowBit Quantized Weights Linear on Metal +template +void linear_lowbit_quant_weights_mps( + id a_buf, + id b_buf, + int64_t qGroupSize, + id sz_buf, + id out_buf, + int32_t M, + int32_t K, + int32_t N, + const std::string_view type_str) { + const std::string shader_path = std::string(METAL_SHADER_DIR) + "/" + + std::string(LowBitConfig::metal_filename); + const std::string shader_func = std::string(LowBitConfig::func_prefix) + + std::to_string(qGroupSize) + "_" + std::string(type_str); + return linear_lowbit_quant_weights_mps_impl( + a_buf, + b_buf, + sz_buf, + out_buf, + M, + K, + N, + shader_path, + shader_func, + LowBitConfig::dispatch_fn); +} + +} // namespace + +// LowBit Quantized Weights Linear & Packing on Metal +template +struct LowBitQuantWeights { + static constexpr auto linear = linear_lowbit_quant_weights_mps; + static constexpr auto pack = LowBitConfig::packing_fn; +}; + +// Reference CPU implementation of lowbit quantized linear +template +void reference_linear_lowbit_quant_weights_cpu( + const T* a_ptr, + const uint8_t* w_ptr, + int64_t group_size, + const T* sz_ptr, + T* out_ptr, + int32_t M, + int32_t K, + int32_t N, + int64_t nbit) { + uint8_t zero_shift = 1 << (nbit - 1); + + for (int32_t m = 0; m < M; m++) { + for (int32_t n = 0; n < N; n++) { + const int32_t k_block = (K + group_size - 1) / group_size; + const T* A_ptr = a_ptr + m * K; + + float rc = 0.0; + int32_t k = 0; + for (int32_t kb = 0; kb < k_block; kb++) { + const float scale = float(sz_ptr[(kb * N + n) * 2 + 0]); + const float zero = + float(sz_ptr[(kb * N + n) * 2 + 1]) - scale * 
float(zero_shift); + for (int32_t idx = 0; idx < group_size && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uint8_t w_val = w_ptr[n * K + k]; + rc += a_val * (scale * float(w_val) + zero); + } + } + + out_ptr[m * N + n] = T(rc); + } + } +} + +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/kernels/mps/src/packing.h b/torchao/experimental/kernels/mps/src/packing.h new file mode 100644 index 0000000000..b28d738f68 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/packing.h @@ -0,0 +1,192 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace packing { + +template +inline void pack(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K); + +template <> +inline void +pack<1>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_val0 = w_ptr[n * K + k8 * 8]; + uint8_t src_val1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_val2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_val3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_val4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_val5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_val6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_val7 = w_ptr[n * K + k8 * 8 + 7]; + b_ptr[n * (K / 8) + k8] = (uint8_t(src_val7) << 7) | + (uint8_t(src_val6) << 6) | (uint8_t(src_val5) << 5) | + (uint8_t(src_val4) << 4) | (uint8_t(src_val3) << 3) | + (uint8_t(src_val2) << 2) | (uint8_t(src_val1) << 1) | + uint8_t(src_val0); + } + } +} + +template <> +inline void +pack<2>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + for (int32_t k4 = 0; k4 < K / 4; k4++) { + uint8_t src_val0 = w_ptr[n * K + k4 * 4]; + uint8_t src_val1 = w_ptr[n * K + k4 * 4 + 1]; + uint8_t src_val2 = w_ptr[n * K + k4 * 4 + 2]; + uint8_t src_val3 = w_ptr[n * K + k4 * 4 + 3]; + b_ptr[n * (K / 4) + k4] = (uint8_t(src_val3) << 6) | + (uint8_t(src_val2) << 4) | (uint8_t(src_val1) << 2) | + uint8_t(src_val0); + } + } +} + +template <> +inline void +pack<3>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * 3 * K / 8; + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_0ab = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_1cd = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_2ef = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_3gh = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_4ij = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_5kl = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_6mn = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_7op = w_ptr[n * K + k8 * 8 + 7]; + + // b0: 7|6|5|4|3|2|1|0 (upper bits for all values) + b_ptr[row_base + 3 * k8 + 0] = ((src_0ab & 4) >> 2) | + ((src_1cd & 4) >> 1) | ((src_2ef & 4)) | ((src_3gh & 4) << 1) | + ((src_4ij & 4) << 2) | ((src_5kl & 4) << 3) | ((src_6mn & 4) << 4) | + ((src_7op & 4) << 5); + + // b1: gh|ef|cd|ab (lower 2 bits for first 4 values) + b_ptr[row_base + 3 * k8 + 1] = (src_0ab & 3) | ((src_1cd & 3) << 2) | + ((src_2ef & 3) << 4) | ((src_3gh & 3) << 6); + + // b2: op|mn|kl|ij (lower 2 bits for last 4 values) + b_ptr[row_base + 3 * k8 + 2] = (src_4ij & 3) | ((src_5kl & 3) << 2) | + ((src_6mn & 3) << 4) | ((src_7op & 3) << 6); + } + } +} + +template <> 
+inline void +pack<4>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + for (int32_t k2 = 0; k2 < K / 2; k2++) { + uint8_t src_val0 = w_ptr[n * K + k2 * 2]; + uint8_t src_val1 = w_ptr[n * K + k2 * 2 + 1]; + b_ptr[n * (K / 2) + k2] = (uint8_t(src_val1) << 4) | uint8_t(src_val0); + } + } +} + +template <> +inline void +pack<5>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * 5 * K / 8; + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_0abAB = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_1cdCD = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_2efEF = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_3ghGH = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_4ijIJ = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_5klKL = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_6mnMN = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_7opOP = w_ptr[n * K + k8 * 8 + 7]; + + // b0: 7|6|5|4|3|2|1|0 (upper bits for all values) + b_ptr[row_base + 5 * k8 + 0] = ((src_0abAB & 16) >> 4) | + ((src_1cdCD & 16) >> 3) | ((src_2efEF & 16) >> 2) | + ((src_3ghGH & 16) >> 1) | ((src_4ijIJ & 16)) | + ((src_5klKL & 16) << 1) | ((src_6mnMN & 16) << 2) | + ((src_7opOP & 16) << 3); + + // b1: cdCD|abAB (lower 4 bits for first 2 values) + b_ptr[row_base + 5 * k8 + 1] = (src_0abAB & 15) | ((src_1cdCD & 15) << 4); + + // b2: ghGH|efEF (lower 4 bits for second 2 values) + b_ptr[row_base + 5 * k8 + 2] = (src_2efEF & 15) | ((src_3ghGH & 15) << 4); + + // b3: klKL|ijIJ (lower 4 bits for third 2 values) + b_ptr[row_base + 5 * k8 + 3] = (src_4ijIJ & 15) | ((src_5klKL & 15) << 4); + + // b4: opOP|mnMN (lower 4 bits for last 2 values) + b_ptr[row_base + 5 * k8 + 4] = (src_6mnMN & 15) | ((src_7opOP & 15) << 4); + } + } +} + +template <> +inline void +pack<6>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * 3 * K / 4; + for (int32_t k4 = 0; k4 < K / 4; k4++) { + uint8_t src_10abcd = w_ptr[n * K + k4 * 4 + 0]; + uint8_t src_32efgh = w_ptr[n * K + k4 * 4 + 1]; + uint8_t src_54ijkl = w_ptr[n * K + k4 * 4 + 2]; + uint8_t src_76mnop = w_ptr[n * K + k4 * 4 + 3]; + + // b0: 76|54|32|10 (upper 2 bits for all values) + b_ptr[row_base + 3 * k4 + 0] = ((src_10abcd & 48) >> 4) | + ((src_32efgh & 48) >> 2) | ((src_54ijkl & 48)) | + ((src_76mnop & 48) << 2); + + // b1: efgh|abcd (lower 4 bits for first 2 values) + b_ptr[row_base + 3 * k4 + 1] = + (src_10abcd & 15) | ((src_32efgh & 15) << 4); + + // b2: mnop|ijkl (lower 4 bits for last 2 values) + b_ptr[row_base + 3 * k4 + 2] = + (src_54ijkl & 15) | ((src_76mnop & 15) << 4); + } + } +} + +template <> +inline void +pack<7>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * 7 * K / 8; + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_0 = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_7 = w_ptr[n * K + k8 * 8 + 7]; + + b_ptr[row_base + 7 * k8 + 0] = src_0 | ((src_7 & 1) << 7); + b_ptr[row_base + 7 * k8 + 1] = src_1 | ((src_7 & 2) << 6); + b_ptr[row_base + 7 * k8 + 2] = src_2 | ((src_7 & 4) << 5); + b_ptr[row_base + 7 * k8 + 3] = src_3 | ((src_7 & 8) << 4); + b_ptr[row_base + 7 * k8 + 4] = src_4 | ((src_7 & 16) << 3); + 
+      b_ptr[row_base + 7 * k8 + 5] = src_5 | ((src_7 & 32) << 2);
+      b_ptr[row_base + 7 * k8 + 6] = src_6 | ((src_7 & 64) << 1);
+    }
+  }
+}
+
+} // namespace packing
+} // namespace lowbit
+} // namespace mps
+} // namespace kernels
+} // namespace torchao
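For intuition about the densest layout above, here is a small Python model of the `pack<7>` scheme: each group of 8 values occupies 7 bytes, with byte i holding value i in its low 7 bits and bit i of the eighth value in its high bit. This sketch is not part of the patch, and `pack7`/`unpack7` are hypothetical helper names.

```
# Illustrative model of the pack<7> layout; not code from this patch.
import random

def pack7(vals):  # len(vals) % 8 == 0, each 0 <= v < 128
    out = []
    for g in range(0, len(vals), 8):
        v = vals[g:g + 8]
        # byte i = value i in bits 0..6, bit i of v[7] in bit 7
        out.extend(v[i] | (((v[7] >> i) & 1) << 7) for i in range(7))
    return out

def unpack7(packed):
    vals = []
    for g in range(0, len(packed), 7):
        b = packed[g:g + 7]
        vals.extend(x & 0x7F for x in b)          # the seven stored values
        vals.append(sum((b[i] >> 7) << i for i in range(7)))  # the eighth
    return vals

w = [random.randrange(128) for _ in range(64)]
assert unpack7(pack7(w)) == w  # the layout round-trips
```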
diff --git a/torchao/experimental/kernels/mps/test/Makefile b/torchao/experimental/kernels/mps/test/Makefile
new file mode 100644
index 0000000000..e8213818c5
--- /dev/null
+++ b/torchao/experimental/kernels/mps/test/Makefile
@@ -0,0 +1,7 @@
+all: test_lowbit
+
+test_lowbit: test_lowbit.mm
+	clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation
+
+run: test_lowbit
+	./test_lowbit
diff --git a/torchao/experimental/kernels/mps/test/bfloat16.h b/torchao/experimental/kernels/mps/test/bfloat16.h
new file mode 100644
index 0000000000..b041c1b492
--- /dev/null
+++ b/torchao/experimental/kernels/mps/test/bfloat16.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+/**
+ * This implementation is copied from
+ * executorch/runtime/core/portable_type/bfloat16.h
+ */
+
+inline float f32_from_bits(uint16_t src) {
+  float res = 0;
+  uint32_t tmp = src;
+  tmp <<= 16;
+  std::memcpy(&res, &tmp, sizeof(tmp));
+  return res;
+}
+
+inline uint16_t bits_from_f32(float src) {
+  uint32_t res = 0;
+  std::memcpy(&res, &src, sizeof(res));
+  return res >> 16;
+}
+
+inline uint16_t round_to_nearest_even(float src) {
+  if (std::isnan(src)) {
+    return UINT16_C(0x7FC0);
+  }
+  uint32_t U32 = 0;
+  std::memcpy(&U32, &src, sizeof(U32));
+  uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
+  return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
+}
+
+/**
+ * The "brain floating-point" type, compatible with c10/util/BFloat16.h from
+ * pytorch core.
+ *
+ * This representation uses 1 bit for the sign, 8 bits for the exponent and 7
+ * bits for the mantissa.
+ */
+struct alignas(2) BFloat16 {
+  uint16_t x;
+
+  BFloat16() = default;
+  struct from_bits_t {};
+  static constexpr from_bits_t from_bits() {
+    return from_bits_t();
+  }
+
+  constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {}
+  /* implicit */ BFloat16(float value) : x(round_to_nearest_even(value)) {}
+  operator float() const {
+    return f32_from_bits(x);
+  }
+};
diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm
new file mode 100644
index 0000000000..427fbceb39
--- /dev/null
+++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm
@@ -0,0 +1,181 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <torchao/experimental/kernels/mps/src/lowbit.h>
+#include <torchao/experimental/kernels/mps/test/bfloat16.h>
+
+#include <iostream>
+#include <random>
+
+using Float16 = _Float16;
+
+template <typename T>
+const std::string_view type_string();
+template <>
+const std::string_view type_string<BFloat16>() {
+  return "bfloat";
+}
+template <>
+const std::string_view type_string<float>() {
+  return "float";
+}
+template <>
+const std::string_view type_string<Float16>() {
+  return "half";
+}
+
+inline id<MTLBuffer> allocSharedBuffer(id<MTLDevice> device, unsigned length) {
+  id<MTLBuffer> rc = [device newBufferWithLength:length
+                                         options:MTLResourceStorageModeShared];
+  if (rc == nil) {
+    fail("Can't allocate " + std::to_string(length) + " bytes on GPU");
+  }
+  return rc;
+}
+
+namespace torchao {
+namespace kernels {
+namespace mps {
+namespace lowbit {
+
+template <typename T, int nbit>
+class LowBitTester {
+ public:
+  LowBitTester(int32_t m, int32_t k, int32_t n, int32_t group_size)
+      : M(m), K(k), N(n), qGroupSize(group_size) {}
+
+  void init() {
+    allocBuffers(MTL_DEVICE);
+
+    T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
+    uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
+    T* c_ptr = reinterpret_cast<T*>([buf_C contents]);
+    T* s_ptr = reinterpret_cast<T*>([buf_SZ contents]);
+    std::random_device rd;
+    std::mt19937 generator(rd());
+    std::uniform_int_distribution<> int_distrib(0, (1 << nbit) - 1);
+    std::uniform_real_distribution<> real_distrib(-1.0, 1.0);
+
+    for (int idx = 0; idx < M * K; ++idx) {
+      a_ptr[idx] = real_distrib(generator);
+    }
+    for (int idx = 0; idx < N * K; ++idx) {
+      w_ptr[idx] = int_distrib(generator);
+    }
+    for (int idx = 0; idx < N * K / qGroupSize; ++idx) {
+      s_ptr[2 * idx] = (idx + 1.0) / N;
+      s_ptr[2 * idx + 1] = 0;
+    }
+    for (int idx = 0; idx < M * N; ++idx) {
+      c_ptr[idx] = -1.0;
+    }
+  }
+
+  void pack() {
+    uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
+    uint8_t* b_ptr = reinterpret_cast<uint8_t*>([buf_B contents]);
+    LowBitQuantWeights<nbit>::pack(w_ptr, b_ptr, N, K);
+  }
+
+  void linear() {
+    LowBitQuantWeights<nbit>::linear(
+        buf_A, buf_B, qGroupSize, buf_SZ, buf_C, M, K, N, type_string<T>());
+  }
+
+  bool validate(float atol_lim = 5e-3, float rtol_lim = 5e-3) const {
+    T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
+    uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
+    T* c_ptr = reinterpret_cast<T*>([buf_C contents]);
+    T* sz_ptr = reinterpret_cast<T*>([buf_SZ contents]);
+
+    // M * N / 4 * sizeof(T) floats == M * N * sizeof(T) bytes == M * N Ts
+    float* e_ptr_f = new float[M * N / 4 * sizeof(T)]; // expected
+    T* e_ptr = reinterpret_cast<T*>(e_ptr_f);
+    reference_linear_lowbit_quant_weights_cpu(
+        a_ptr, w_ptr, qGroupSize, sz_ptr, e_ptr, M, K, N, nbit);
+
+    for (int m = 0; m < M; m++) {
+      for (int n = 0; n < N; n++) {
+        float rc = float(c_ptr[m * N + n]);
+        float expected = float(e_ptr[m * N + n]);
+
+        auto atol = std::abs(rc - expected);
+        auto rtol =
+            atol / std::max(std::min(std::abs(expected), std::abs(rc)), 1e-6f);
+        if (rtol > rtol_lim && atol > atol_lim) {
+          std::cerr << "Result " << rc << " vs expected " << expected
+                    << " (atol=" << atol << " ,rtol=" << rtol << ") at " << m
+                    << ":" << n << std::endl;
+          delete[] e_ptr_f;
+          return false;
+        }
+      }
+    }
+    delete[] e_ptr_f;
+    return true;
+  }
+
+ private:
+  void allocBuffers(id<MTLDevice> device) {
+    const int32_t elem_size = sizeof(T);
+    buf_A = allocSharedBuffer(device, M * K * elem_size);
+    buf_W = allocSharedBuffer(device, N * K);
+    buf_B = allocSharedBuffer(device, N * nbit * K / 8);
+    buf_C = allocSharedBuffer(device, M * N * elem_size);
+    buf_SZ = allocSharedBuffer(device, N * K / qGroupSize * 2 * elem_size);
+  }
+
+ public:
+  int32_t M, K, N; // Input-output matrix dims
+  int32_t qGroupSize;
+  id<MTLBuffer> buf_A; // MxK elements
+  id<MTLBuffer> buf_W; // NxK elements
+  id<MTLBuffer> buf_B; // NxK elements (packed)
+  id<MTLBuffer> buf_C; // MxN elements
+  id<MTLBuffer> buf_SZ; // (K/group_size)xNx2 elements
+};
+
+} // namespace lowbit
+} // namespace mps
+} // namespace kernels
+} // namespace torchao
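Note that `validate()` flags a mismatch only when a value violates *both* the absolute and the relative bound; passing either one is enough to accept it. A hypothetical Python restatement of that predicate, for clarity:

```
# Restatement of validate()'s accept/reject rule; `close` is illustrative.
def close(result, expected, atol_lim=5e-3, rtol_lim=5e-3):
    atol = abs(result - expected)
    rtol = atol / max(min(abs(expected), abs(result)), 1e-6)
    return atol <= atol_lim or rtol <= rtol_lim  # mismatch only if both fail

assert close(1.000, 1.004)   # tiny absolute error passes
assert close(100.0, 100.4)   # 0.4% relative error passes despite atol
assert not close(0.5, 0.6)   # violates both bounds
```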
+template <typename T, int nbit>
+void run_test() {
+  torchao::kernels::mps::lowbit::LowBitTester<T, nbit> tester(
+      256, 256, 256, 32);
+  tester.init();
+  tester.pack();
+  tester.linear();
+  bool success = tester.validate();
+  std::cout << "Test " << type_string<T>() << " " << nbit << "-bit "
+            << (success ? "succeeded" : "failed") << std::endl;
+}
+
+int main() {
+  run_test<float, 1>();
+  run_test<float, 2>();
+  run_test<float, 3>();
+  run_test<float, 4>();
+  run_test<float, 5>();
+  run_test<float, 6>();
+  run_test<float, 7>();
+
+  run_test<Float16, 1>();
+  run_test<Float16, 2>();
+  run_test<Float16, 3>();
+  run_test<Float16, 4>();
+  run_test<Float16, 5>();
+  run_test<Float16, 6>();
+  run_test<Float16, 7>();
+
+  run_test<BFloat16, 1>();
+  run_test<BFloat16, 2>();
+  run_test<BFloat16, 3>();
+  run_test<BFloat16, 4>();
+  run_test<BFloat16, 5>();
+  run_test<BFloat16, 6>();
+  run_test<BFloat16, 7>();
+
+  return 0;
+}
diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/register.mm
new file mode 100644
index 0000000000..4f5a5cd1b2
--- /dev/null
+++ b/torchao/experimental/ops/mps/register.mm
@@ -0,0 +1,114 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// clang-format off
+#include <torch/extension.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <torchao/experimental/kernels/mps/src/lowbit.h>
+// clang-format on
+
+namespace torchao {
+namespace kernels {
+namespace mps {
+namespace lowbit {
+namespace aten {
+
+using Tensor = at::Tensor;
+using namespace at::native::mps;
+
+// LowBit Quantized Linear on MPS Backend
+template <int nbit>
+Tensor linear_mps_kernel(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& SZ) {
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
+  auto C = at::empty({M, N}, A.options());
+
+  LowBitQuantWeights<nbit>::linear(
+      getMTLBufferStorage(A),
+      getMTLBufferStorage(B),
+      group_size,
+      getMTLBufferStorage(SZ),
+      getMTLBufferStorage(C),
+      M,
+      K,
+      N,
+      scalarToMetalTypeString(A));
+
+  return C;
+}
+
+// LowBit Packing on CPU Backend
+template <int nbit>
+Tensor pack_weights_cpu_kernel(const Tensor& W) {
+  auto N = W.size(0);
+  auto K = W.size(1);
+  auto B = at::empty({N, nbit * K / 8}, W.options());
+
+  uint8_t* w_ptr = W.data_ptr<uint8_t>();
+  uint8_t* b_ptr = B.data_ptr<uint8_t>();
+
+  LowBitQuantWeights<nbit>::pack(w_ptr, b_ptr, N, K);
+
+  return B;
+}
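The packed tensor allocated above has N rows of nbit·K/8 bytes, since each row's K nbit-wide codes occupy nbit·K bits. A quick sanity check of that bookkeeping, with `packed_shape` as a hypothetical helper rather than anything in the patch:

```
# Illustrative check of the packed-buffer sizing in pack_weights_cpu_kernel.
def packed_shape(N, K, nbit):
    assert (nbit * K) % 8 == 0, "rows must stay byte-aligned"
    return (N, nbit * K // 8)

assert packed_shape(64, 128, 3) == (64, 48)   # 128 3-bit codes -> 48 bytes
assert packed_shape(64, 128, 7) == (64, 112)  # 128 7-bit codes -> 112 bytes
```

+
+// Registers _C as a Python extension module.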
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} + +TORCH_LIBRARY(torchao, m) { + m.def("_pack_weight_1bit(Tensor W) -> Tensor"); + m.def("_pack_weight_2bit(Tensor W) -> Tensor"); + m.def("_pack_weight_3bit(Tensor W) -> Tensor"); + m.def("_pack_weight_4bit(Tensor W) -> Tensor"); + m.def("_pack_weight_5bit(Tensor W) -> Tensor"); + m.def("_pack_weight_6bit(Tensor W) -> Tensor"); + m.def("_pack_weight_7bit(Tensor W) -> Tensor"); + m.def( + "_linear_fp_act_1bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_2bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_3bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_4bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_5bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("_pack_weight_1bit", &pack_weights_cpu_kernel<1>); + m.impl("_pack_weight_2bit", &pack_weights_cpu_kernel<2>); + m.impl("_pack_weight_3bit", &pack_weights_cpu_kernel<3>); + m.impl("_pack_weight_4bit", &pack_weights_cpu_kernel<4>); + m.impl("_pack_weight_5bit", &pack_weights_cpu_kernel<5>); + m.impl("_pack_weight_6bit", &pack_weights_cpu_kernel<6>); + m.impl("_pack_weight_7bit", &pack_weights_cpu_kernel<7>); +} + +TORCH_LIBRARY_IMPL(torchao, MPS, m) { + m.impl("_linear_fp_act_1bit_weight", &linear_mps_kernel<1>); + m.impl("_linear_fp_act_2bit_weight", &linear_mps_kernel<2>); + m.impl("_linear_fp_act_3bit_weight", &linear_mps_kernel<3>); + m.impl("_linear_fp_act_4bit_weight", &linear_mps_kernel<4>); + m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>); + m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>); + m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>); +} + +} // namespace aten +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/ops/mps/setup.py b/torchao/experimental/ops/mps/setup.py new file mode 100644 index 0000000000..e9c206cdb9 --- /dev/null +++ b/torchao/experimental/ops/mps/setup.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from setuptools import setup +from torch.utils.cpp_extension import CppExtension, BuildExtension + +setup( + name="torchao_mps_ops", + version="1.0", + ext_modules=[ + CppExtension( + name="torchao_mps_ops", + sources=["register.mm"], + include_dirs=[os.getenv("TORCHAO_ROOT")], + extra_compile_args=["-DATEN=1"], + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py new file mode 100644 index 0000000000..a401eb60a5 --- /dev/null +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
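Before the test itself, it may help to see the call pattern it exercises. Once the `torchao_mps_ops` extension is built with the setup.py above and imported, the registered ops are reachable via `torch.ops.torchao`. A hypothetical sequence with illustrative shapes:

```
# Hypothetical usage of the registered ops; shapes are illustrative only.
import torch
import torchao_mps_ops  # noqa: F401  (import registers the torchao ops)

N, K, M, group_size, nbit = 64, 128, 8, 32, 4
W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8)  # quantized codes
B = torch.ops.torchao._pack_weight_4bit(W).to("mps")        # pack on CPU
A = torch.rand(M, K, dtype=torch.float32, device="mps")     # activations
SZ = torch.stack(                                           # (K/gs, N, 2)
    (torch.ones(K // group_size, N, device="mps"),          # scales
     torch.zeros(K // group_size, N, device="mps")), dim=2) # zeros
C = torch.ops.torchao._linear_fp_act_4bit_weight(A, B, group_size, SZ)
assert C.shape == (M, N)
```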
+ +import torch +import torchao_mps_ops +import unittest + + +def parameterized(test_cases): + def decorator(func): + def wrapper(self): + for case in test_cases: + with self.subTest(case=case): + func(self, *case) + + return wrapper + + return decorator + + +class TestLowBitQuantWeightsLinear(unittest.TestCase): + cases = [ + (nbit, *param) + for nbit in range(1, 8) + for param in [ + (1, 32, 32, 32), + (128, 48, 64, 32), + (17, 1024, 512, 256), + ] + ] + + def _init_tensors(self, group_size, M, K, N, nbit, device="mps"): + max_abs = 1 << (nbit - 1) + ceil_K_group_size = (K + group_size - 1) // group_size + A = 2 * torch.rand(M, K, dtype=torch.float32, device=device) - 1 + W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device) + S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01 + Z = torch.randint( + -max_abs, + max_abs, + (ceil_K_group_size, N), + dtype=torch.float32, + device=device, + ) + SZ = torch.stack((S, Z), dim=2) + return A, W, SZ + + def _reference_linear_lowbit_quant_weights(self, A, W, group_size, SZ, nbit): + # A is (M, K) + # W is (N, K) + # SZ is (K // group_size, N, 2) + N = W.shape[0] + K = W.shape[1] + max_abs = 1 << (nbit - 1) + W = W.to(torch.float32) - max_abs + scales = ( + SZ[:, :, 0].t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + ) + zeros = SZ[:, :, 1].t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + W = scales * W + zeros + return torch.mm(A, W.t()) + + @parameterized(cases) + def test_linear(self, nbit, M=1, K=32, N=32, group_size=32): + print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}") + A, W, SZ = self._init_tensors(group_size, M, K, N, nbit=nbit) + packing_op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") + linear_op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + B = packing_op(W.cpu()).to("mps") + result = linear_op(A, B, group_size, SZ).cpu() + expected = self._reference_linear_lowbit_quant_weights( + A.cpu(), W.cpu(), group_size, SZ.cpu(), nbit=nbit + ) + max_err = torch.max(torch.abs(result - expected)) + self.assertLess(max_err, 0.001) + + +if __name__ == "__main__": + unittest.main()
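For reference, the dequantization that `_reference_linear_lowbit_quant_weights` applies to each weight is W_dq = scale * (w - 2^(nbit-1)) + zero, with one (scale, zero) pair per group of `group_size` elements along K. A small worked example of that formula, with made-up numbers:

```
# Worked example of the reference dequantization; values are illustrative.
nbit, group_size = 4, 2
w = [3, 12, 7, 0]               # one output row, K = 4 unsigned codes
s, z = [0.5, 2.0], [1.0, -1.0]  # one (scale, zero) pair per group
max_abs = 1 << (nbit - 1)       # recentering offset, here 8
w_dq = [s[k // group_size] * (w[k] - max_abs) + z[k // group_size]
        for k in range(len(w))]
assert w_dq == [-1.5, 3.0, -3.0, -17.0]
```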