From 6c16c82f7408f9ae260178655c93645cd80708b5 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 10 Oct 2024 05:30:32 -0700 Subject: [PATCH] Introduce lowbit quantized linear MPS kernels (#954) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The following is the directory structure of the submitted code under torchao ``` experimental/ ├── kernels/ │ └── mps/ │ ├── metal/ │ │ └── (metal shaders) │ ├── src/ │ │ └── (tensor agnostic mps kernel implementations) │ └── test/ │ │ └── (directly test mps kernel implementations) └── ops/ └── mps/ ├── register.mm ├── setup.py └── test/ └── (test torch custom ops) ``` Differential Revision: D63342895 --- .../kernels/mps/metal/divbit.metal | 106 ++++++++ .../kernels/mps/metal/int3mm.metal | 97 +++++++ .../kernels/mps/metal/int5mm.metal | 99 +++++++ .../kernels/mps/metal/int6mm.metal | 84 ++++++ .../kernels/mps/metal/int7mm.metal | 101 ++++++++ .../kernels/mps/src/OperationUtils.h | 142 +++++++++++ .../experimental/kernels/mps/src/dispatch.h | 32 +++ torchao/experimental/kernels/mps/src/lowbit.h | 183 +++++++++++++ .../experimental/kernels/mps/src/packing.h | 239 +++++++++++++++++ .../experimental/kernels/mps/test/Makefile | 7 + .../experimental/kernels/mps/test/bfloat16.h | 61 +++++ .../kernels/mps/test/test_lowbit.mm | 241 ++++++++++++++++++ torchao/experimental/ops/mps/register.mm | 145 +++++++++++ torchao/experimental/ops/mps/setup.py | 23 ++ .../experimental/ops/mps/test/test_lowbit.py | 81 ++++++ 15 files changed, 1641 insertions(+) create mode 100644 torchao/experimental/kernels/mps/metal/divbit.metal create mode 100644 torchao/experimental/kernels/mps/metal/int3mm.metal create mode 100644 torchao/experimental/kernels/mps/metal/int5mm.metal create mode 100644 torchao/experimental/kernels/mps/metal/int6mm.metal create mode 100644 torchao/experimental/kernels/mps/metal/int7mm.metal create mode 100644 torchao/experimental/kernels/mps/src/OperationUtils.h create mode 100644 torchao/experimental/kernels/mps/src/dispatch.h create mode 100644 torchao/experimental/kernels/mps/src/lowbit.h create mode 100644 torchao/experimental/kernels/mps/src/packing.h create mode 100644 torchao/experimental/kernels/mps/test/Makefile create mode 100644 torchao/experimental/kernels/mps/test/bfloat16.h create mode 100644 torchao/experimental/kernels/mps/test/test_lowbit.mm create mode 100644 torchao/experimental/ops/mps/register.mm create mode 100644 torchao/experimental/ops/mps/setup.py create mode 100644 torchao/experimental/ops/mps/test/test_lowbit.py diff --git a/torchao/experimental/kernels/mps/metal/divbit.metal b/torchao/experimental/kernels/mps/metal/divbit.metal new file mode 100644 index 0000000000..68f5f7dc03 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/divbit.metal @@ -0,0 +1,106 @@ +#include +using namespace metal; + +/** + * LowBit Quantized Linear for bitwidths that are divisors of 8. Hence the name. + * + * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8) + * @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. 
Expected shape is #groups x N x 2 + * @param[outputData] M x N output tensor of floating point dtype (same as input) + * @param[sizes] The sizes involved in the order: M, K, N + * + * Dispatched threads: N x M x 1 + */ +template +kernel void divbit_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B; + + constexpr uint8_t zero_shift = 1 << (nbit - 1); + constexpr uint8_t values_per_byte = 8 / nbit; + constexpr uint8_t minimask = (1 << nbit) - 1; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const T scale = scalesAndZeros[(kb * N + n) * 2 + 0]; + const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift); + for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uint8_t b_val = B_ptr[(n * K + k) / values_per_byte]; + uint8_t shift = nbit * (k % values_per_byte); + uint8_t mask = minimask << shift; + b_val = (b_val & mask) >> shift; + rc += a_val * float(scale * T(b_val) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \ +template \ +[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void divbit_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_DIVBIT_MM(1, float, 32); +INSTANTIATE_DIVBIT_MM(1, half, 32); +INSTANTIATE_DIVBIT_MM(1, float, 64); +INSTANTIATE_DIVBIT_MM(1, half, 64); +INSTANTIATE_DIVBIT_MM(1, float, 128); +INSTANTIATE_DIVBIT_MM(1, half, 128); +INSTANTIATE_DIVBIT_MM(1, float, 256); +INSTANTIATE_DIVBIT_MM(1, half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_DIVBIT_MM(1, bfloat, 32); +INSTANTIATE_DIVBIT_MM(1, bfloat, 64); +INSTANTIATE_DIVBIT_MM(1, bfloat, 128); +INSTANTIATE_DIVBIT_MM(1, bfloat, 256); +#endif + +INSTANTIATE_DIVBIT_MM(2, float, 32); +INSTANTIATE_DIVBIT_MM(2, half, 32); +INSTANTIATE_DIVBIT_MM(2, float, 64); +INSTANTIATE_DIVBIT_MM(2, half, 64); +INSTANTIATE_DIVBIT_MM(2, float, 128); +INSTANTIATE_DIVBIT_MM(2, half, 128); +INSTANTIATE_DIVBIT_MM(2, float, 256); +INSTANTIATE_DIVBIT_MM(2, half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_DIVBIT_MM(2, bfloat, 32); +INSTANTIATE_DIVBIT_MM(2, bfloat, 64); +INSTANTIATE_DIVBIT_MM(2, bfloat, 128); +INSTANTIATE_DIVBIT_MM(2, bfloat, 256); +#endif + +INSTANTIATE_DIVBIT_MM(4, float, 32); +INSTANTIATE_DIVBIT_MM(4, half, 32); +INSTANTIATE_DIVBIT_MM(4, float, 64); +INSTANTIATE_DIVBIT_MM(4, half, 64); +INSTANTIATE_DIVBIT_MM(4, float, 128); +INSTANTIATE_DIVBIT_MM(4, half, 128); +INSTANTIATE_DIVBIT_MM(4, float, 256); +INSTANTIATE_DIVBIT_MM(4, half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_DIVBIT_MM(4, bfloat, 32); +INSTANTIATE_DIVBIT_MM(4, bfloat, 64); +INSTANTIATE_DIVBIT_MM(4, bfloat, 128); +INSTANTIATE_DIVBIT_MM(4, bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int3mm.metal b/torchao/experimental/kernels/mps/metal/int3mm.metal new file mode 
100644 index 0000000000..8fd68cd768 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int3mm.metal @@ -0,0 +1,97 @@ +#include +using namespace metal; + +/** + * 3-Bit Quantized Linear. + * + * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8) + * @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2 + * @param[outputData] M x N output tensor of floating point dtype (same as input) + * @param[sizes] The sizes involved in the order: M, K, N + * + * Dispatched threads: N x M x 1 + */ +template +kernel void int3pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 3 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + + uchar b0 = B_ptr[3 * (k / 8) + 0]; + uchar b1 = B_ptr[3 * (k / 8) + 1]; + uchar b2 = B_ptr[3 * (k / 8) + 2]; + + uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3); + uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2); + uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4); + uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6); + + uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3); + uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2); + uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4); + uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \ +template \ +[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int3pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT3MM(float, 32); +INSTANTIATE_INT3MM(half, 32); +INSTANTIATE_INT3MM(float, 64); +INSTANTIATE_INT3MM(half, 64); +INSTANTIATE_INT3MM(float, 128); +INSTANTIATE_INT3MM(half, 128); +INSTANTIATE_INT3MM(float, 256); +INSTANTIATE_INT3MM(half, 256); +#if __METAL_VERSION__ >= 310 
+INSTANTIATE_INT3MM(bfloat, 32); +INSTANTIATE_INT3MM(bfloat, 64); +INSTANTIATE_INT3MM(bfloat, 128); +INSTANTIATE_INT3MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int5mm.metal b/torchao/experimental/kernels/mps/metal/int5mm.metal new file mode 100644 index 0000000000..84aba20725 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int5mm.metal @@ -0,0 +1,99 @@ +#include +using namespace metal; + +/** + * 5-Bit Quantized Linear. + * + * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8) + * @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2 + * @param[outputData] M x N output tensor of floating point dtype (same as input) + * @param[sizes] The sizes involved in the order: M, K, N + * + * Dispatched threads: N x M x 1 + */ +template +kernel void int5pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 5 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + + uchar b0 = B_ptr[5 * (k / 8) + 0]; + uchar b1 = B_ptr[5 * (k / 8) + 1]; + uchar b2 = B_ptr[5 * (k / 8) + 2]; + uchar b3 = B_ptr[5 * (k / 8) + 3]; + uchar b4 = B_ptr[5 * (k / 8) + 4]; + + uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15); + uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4); + uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15); + uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4); + + uchar w_val4 = ((b0 & 16)) | (b3 & 15); + uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4); + uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15); + uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT5MM(DTYPE, GSIZE) \ +template \ +[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int5pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes 
[[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT5MM(float, 32); +INSTANTIATE_INT5MM(half, 32); +INSTANTIATE_INT5MM(float, 64); +INSTANTIATE_INT5MM(half, 64); +INSTANTIATE_INT5MM(float, 128); +INSTANTIATE_INT5MM(half, 128); +INSTANTIATE_INT5MM(float, 256); +INSTANTIATE_INT5MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT5MM(bfloat, 32); +INSTANTIATE_INT5MM(bfloat, 64); +INSTANTIATE_INT5MM(bfloat, 128); +INSTANTIATE_INT5MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int6mm.metal b/torchao/experimental/kernels/mps/metal/int6mm.metal new file mode 100644 index 0000000000..7b99b749e7 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int6mm.metal @@ -0,0 +1,84 @@ +#include +using namespace metal; + +/** + * 6-Bit Quantized Linear. + * + * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8) + * @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2 + * @param[outputData] M x N output tensor of floating point dtype (same as input) + * @param[sizes] The sizes involved in the order: M, K, N + * + * Dispatched threads: N x M x 1 + */ +template +kernel void int6pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 3 * K / 4; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32); + for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + + uchar b0 = B_ptr[3 * (k / 4) + 0]; + uchar b1 = B_ptr[3 * (k / 4) + 1]; + uchar b2 = B_ptr[3 * (k / 4) + 2]; + + uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15); + uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4); + uchar w_val2 = ((b0 & 48)) | (b2 & 15); + uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT6MM(DTYPE, GSIZE) \ +template \ +[[host_name("int6pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int6pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT6MM(float, 32); +INSTANTIATE_INT6MM(half, 32); +INSTANTIATE_INT6MM(float, 64); +INSTANTIATE_INT6MM(half, 64); +INSTANTIATE_INT6MM(float, 128); +INSTANTIATE_INT6MM(half, 128); +INSTANTIATE_INT6MM(float, 256); 
+INSTANTIATE_INT6MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT6MM(bfloat, 32); +INSTANTIATE_INT6MM(bfloat, 64); +INSTANTIATE_INT6MM(bfloat, 128); +INSTANTIATE_INT6MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int7mm.metal b/torchao/experimental/kernels/mps/metal/int7mm.metal new file mode 100644 index 0000000000..bcd03f50f7 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int7mm.metal @@ -0,0 +1,101 @@ +#include +using namespace metal; + +/** + * 7-Bit Quantized Linear. + * + * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8) + * @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2 + * @param[outputData] M x N output tensor of floating point dtype (same as input) + * @param[sizes] The sizes involved in the order: M, K, N + * + * Dispatched threads: N x M x 1 + */ +template +kernel void int7pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 thread_index [[thread_position_in_grid]]) { + const uint K = sizes.y; + const uint N = sizes.z; + const uint m = thread_index.y; // 0..M-1 + const uint n = thread_index.x; // 0..N-1 + const uint32_t k_block = (K + groupSize - 1) / groupSize; + constant T *A_ptr = A + m * K; + constant uchar *B_ptr = B + n * 7 * K / 8; + + float rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); + const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(64); + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { + const auto a_val0 = float(A_ptr[k + 0]); + const auto a_val1 = float(A_ptr[k + 1]); + const auto a_val2 = float(A_ptr[k + 2]); + const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + + uchar b0 = B_ptr[7 * (k / 8) + 0]; + uchar b1 = B_ptr[7 * (k / 8) + 1]; + uchar b2 = B_ptr[7 * (k / 8) + 2]; + uchar b3 = B_ptr[7 * (k / 8) + 3]; + uchar b4 = B_ptr[7 * (k / 8) + 4]; + uchar b5 = B_ptr[7 * (k / 8) + 5]; + uchar b6 = B_ptr[7 * (k / 8) + 6]; + + uchar w_val0 = b0 & 127; + uchar w_val1 = b1 & 127; + uchar w_val2 = b2 & 127; + uchar w_val3 = b3 & 127; + uchar w_val4 = b4 & 127; + uchar w_val5 = b5 & 127; + uchar w_val6 = b6 & 127; + uchar w_val7 = ((b0 & 128) >> 7) | ((b1 & 128) >> 6) | ((b2 & 128) >> 5) | ((b3 & 128) >> 4) + | ((b4 & 128) >> 3) | ((b5 & 128) >> 2) | ((b6 & 128) >> 1); + + rc += a_val0 * (scale * float(w_val0) + zero); + rc += a_val1 * (scale * float(w_val1) + zero); + rc += a_val2 * (scale * float(w_val2) + zero); + rc += a_val3 * (scale * float(w_val3) + zero); + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); + } + } + outputData[m * N + n] = T(rc); +} + +#define INSTANTIATE_INT7MM(DTYPE, GSIZE) \ +template \ +[[host_name("int7pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int7pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + 
device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_INT7MM(float, 32); +INSTANTIATE_INT7MM(half, 32); +INSTANTIATE_INT7MM(float, 64); +INSTANTIATE_INT7MM(half, 64); +INSTANTIATE_INT7MM(float, 128); +INSTANTIATE_INT7MM(half, 128); +INSTANTIATE_INT7MM(float, 256); +INSTANTIATE_INT7MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT7MM(bfloat, 32); +INSTANTIATE_INT7MM(bfloat, 64); +INSTANTIATE_INT7MM(bfloat, 128); +INSTANTIATE_INT7MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.h b/torchao/experimental/kernels/mps/src/OperationUtils.h new file mode 100644 index 0000000000..80857b1302 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/OperationUtils.h @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +static void fail(const std::string& str) { + std::cerr << str << std::endl; + abort(); +} + +static void fail(const std::string& str1, const std::string& str2) { + std::cerr << str1 << str2 << std::endl; + abort(); +} + +inline void dispatch_sync_with_rethrow( + id queue, + void (^block)()) { + (void)queue; + block(); +} + +inline id getMetalDevice() { + NSArray* devices = [MTLCopyAllDevices() autorelease]; + if (devices.count == 0) { + fail("Metal is not supported"); + } + return devices[0]; +} + +static id MTL_DEVICE = getMetalDevice(); + +static id compileLibraryFromSource( + id device, + const std::string& source) { + NSError* error = nil; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:MTLLanguageVersion3_1]; + NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; + id library = [device newLibraryWithSource:kernel_source + options:options + error:&error]; + if (library == nil) { + fail("Failed to compile: ", error.description.UTF8String); + } + return library; +} + +class MetalShaderLibrary { + public: + MetalShaderLibrary(const std::string& src) : shaderSource(src) { + lib = compileLibraryFromSource(device, shaderSource); + } + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + MetalShaderLibrary(MetalShaderLibrary&&) = delete; + + id getPipelineStateForFunc( + const std::string& fname) { + return get_compute_pipeline_state(load_func(fname)); + } + + private: + std::string shaderSource; + id device = MTL_DEVICE; + id lib = nil; + + id load_func(const std::string& func_name) const { + id func = [lib + newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; + if (func == nil) { + fail("Can't get function:" + func_name); + } + return func; + } + + id get_compute_pipeline_state( + id func) const { + NSError* error = nil; + auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; + if (cpl == nil) { + fail( + "Failed to construct pipeline state: ", error.description.UTF8String); + } + return cpl; + } +}; + +class MPSStream { + public: + MPSStream() { + _commandQueue = [MTL_DEVICE newCommandQueue]; + } + + ~MPSStream() { + [_commandQueue release]; + _commandQueue = nil; + + assert(_commandBuffer == nil); + } + + id queue() const { + return _commandQueue; + } + + id commandBuffer() { + if (!_commandBuffer) { + auto desc = [MTLCommandBufferDescriptor new]; + desc.errorOptions = MTLCommandBufferErrorOptionEncoderExecutionStatus; + _commandBuffer = 
[_commandQueue commandBufferWithDescriptor:desc]; + } + return _commandBuffer; + } + + id commandEncoder() { + if (!_commandEncoder) { + _commandEncoder = [commandBuffer() computeCommandEncoder]; + } + return _commandEncoder; + } + + private: + id _commandQueue = nil; + id _commandBuffer = nil; + id _commandEncoder = nil; +}; + +inline void finalize_block(MPSStream* mpsStream) { + id encoder = mpsStream->commandEncoder(); + id cmdBuffer = mpsStream->commandBuffer(); + [encoder endEncoding]; + [cmdBuffer commit]; + [cmdBuffer waitUntilCompleted]; +} + +inline MPSStream* getCurrentMPSStream() { + return new MPSStream(); +} diff --git a/torchao/experimental/kernels/mps/src/dispatch.h b/torchao/experimental/kernels/mps/src/dispatch.h new file mode 100644 index 0000000000..1e98149e54 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/dispatch.h @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace dispatch { + +inline void dispatch_mm( + id encoder, + int32_t maxThreadsPerGroup, + int32_t M, + int32_t N, + int32_t K) { + (void)K; + [encoder dispatchThreads:MTLSizeMake(N, M, 1) + threadsPerThreadgroup:MTLSizeMake(std::min(maxThreadsPerGroup, M), 1, 1)]; +} + +} // namespace dispatch +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h new file mode 100644 index 0000000000..c7f260e0f3 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -0,0 +1,183 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
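+
+// This header selects, per bitwidth, the Metal shader file, kernel name prefix,
+// CPU packing routine and thread dispatch function (LowBitConfig<nbit>), and
+// exposes them through LowBitQuantWeights<nbit>::pack / ::linear.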
+ +#pragma once + +#include +#include + +#include +#include + +#include +#include + +#ifdef ATEN +#include +using namespace at::native::mps; +inline void finalize_block(MPSStream* mpsStream) {} +#else +#include +#endif + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace { + +static constexpr std::string_view METAL_SHADER_DIR = + "/Users/mcandales/fbsource/xplat/pytorch/ao/torchao/experimental/kernels/mps/metal"; + +template +struct LowBitConfig {}; + +template <> +struct LowBitConfig<1> { + static constexpr std::string_view metal_filename = "divbit.metal"; + static constexpr std::string_view func_prefix = "int1pack_mm_"; + static constexpr auto packing_fn = packing::pack<1>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<2> { + static constexpr std::string_view metal_filename = "divbit.metal"; + static constexpr std::string_view func_prefix = "int2pack_mm_"; + static constexpr auto packing_fn = packing::pack<2>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<3> { + static constexpr std::string_view metal_filename = "int3mm.metal"; + static constexpr std::string_view func_prefix = "int3pack_mm_"; + static constexpr auto packing_fn = packing::pack<3>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<4> { + static constexpr std::string_view metal_filename = "divbit.metal"; + static constexpr std::string_view func_prefix = "int4pack_mm_"; + static constexpr auto packing_fn = packing::pack<4>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<5> { + static constexpr std::string_view metal_filename = "int5mm.metal"; + static constexpr std::string_view func_prefix = "int5pack_mm_"; + static constexpr auto packing_fn = packing::pack<5>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<6> { + static constexpr std::string_view metal_filename = "int6mm.metal"; + static constexpr std::string_view func_prefix = "int6pack_mm_"; + static constexpr auto packing_fn = packing::pack<6>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +template <> +struct LowBitConfig<7> { + static constexpr std::string_view metal_filename = "int7mm.metal"; + static constexpr std::string_view func_prefix = "int7pack_mm_"; + static constexpr auto packing_fn = packing::pack<7>; + static constexpr auto dispatch_fn = dispatch::dispatch_mm; +}; + +inline MetalShaderLibrary compileLibraryFromFile(const std::string& fname) { + std::ifstream ifs(fname); + std::stringstream ss; + ss << ifs.rdbuf(); + return MetalShaderLibrary(ss.str()); +} + +using DispatchFn = void (*)(id, int32_t, int32_t, int32_t, int32_t); + +inline void linear_lowbit_quant_weights_mps_impl( + id a_buf, + id b_buf, + id sz_buf, + id out_buf, + int32_t M, + int32_t K, + int32_t N, + const std::string shader_path, + const std::string shader_func, + DispatchFn dispatch_fn) { + std::array sizes = { + static_cast(M), + static_cast(K), + static_cast(N), + 0}; + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MetalShaderLibrary lib = compileLibraryFromFile(shader_path); + id cpl = + lib.getPipelineStateForFunc(shader_func); + const auto maxThreadsPerGroup = [cpl maxTotalThreadsPerThreadgroup]; + [computeEncoder 
setComputePipelineState:cpl]; + [computeEncoder setBuffer:a_buf offset:0 atIndex:0]; + [computeEncoder setBuffer:b_buf offset:0 atIndex:1]; + [computeEncoder setBuffer:sz_buf offset:0 atIndex:2]; + [computeEncoder setBuffer:out_buf offset:0 atIndex:3]; + [computeEncoder setBytes:sizes.data() + length:sizeof(uint32_t) * sizes.size() + atIndex:4]; + dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K); + finalize_block(mpsStream); + } + }); +} + +// LowBit Quantized Weights Linear on Metal +template +void linear_lowbit_quant_weights_mps( + id a_buf, + id b_buf, + int64_t qGroupSize, + id sz_buf, + id out_buf, + int32_t M, + int32_t K, + int32_t N, + const std::string_view type_str) { + const std::string shader_path = std::string(METAL_SHADER_DIR) + "/" + + std::string(LowBitConfig::metal_filename); + const std::string shader_func = std::string(LowBitConfig::func_prefix) + + std::to_string(qGroupSize) + "_" + std::string(type_str); + return linear_lowbit_quant_weights_mps_impl( + a_buf, + b_buf, + sz_buf, + out_buf, + M, + K, + N, + shader_path, + shader_func, + LowBitConfig::dispatch_fn); +} + +} // namespace + +// LowBit Quantized Weights Linear & Packing on Metal +template +struct LowBitQuantWeights { + static constexpr auto linear = linear_lowbit_quant_weights_mps; + static constexpr auto pack = LowBitConfig::packing_fn; +}; + +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/kernels/mps/src/packing.h b/torchao/experimental/kernels/mps/src/packing.h new file mode 100644 index 0000000000..276f876260 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/packing.h @@ -0,0 +1,239 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace packing { + +/** + * Pack weights into a smaller number of bits. + * + * @param[in] w_ptr The input weight tensor. + * @param[out] b_ptr The output packed weight tensor. + * @param[in] N The number of rows in the weight matrix. + * @param[in] K The number of columns in the weight matrix. + */ +template +inline void pack(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K); + +/** + * All packing functions are implemented here. All of them pack the weights + * along the K dimension. + */ + +/** + * 1-bit packing. Each weight is a single bit, so we pack 8 weights into a byte. + */ +template <> +inline void +pack<1>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * (K / 8); + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_val0 = w_ptr[n * K + k8 * 8]; + uint8_t src_val1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_val2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_val3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_val4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_val5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_val6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_val7 = w_ptr[n * K + k8 * 8 + 7]; + b_ptr[row_base + k8] = (uint8_t(src_val7) << 7) | + (uint8_t(src_val6) << 6) | (uint8_t(src_val5) << 5) | + (uint8_t(src_val4) << 4) | (uint8_t(src_val3) << 3) | + (uint8_t(src_val2) << 2) | (uint8_t(src_val1) << 1) | + uint8_t(src_val0); + } + } +} + +/** + * 2-bit packing. Each weight is two bits, so we pack 4 weights into a byte. 
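+ *
+ * For example (arbitrary values): weights {w0, w1, w2, w3} = {1, 2, 3, 0} pack
+ * into the single byte 0b00111001 = 0x39, with w0 in bits 0-1 and w3 in bits 6-7.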
+ */ +template <> +inline void +pack<2>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * (K / 4); + for (int32_t k4 = 0; k4 < K / 4; k4++) { + uint8_t src_val0 = w_ptr[n * K + k4 * 4]; + uint8_t src_val1 = w_ptr[n * K + k4 * 4 + 1]; + uint8_t src_val2 = w_ptr[n * K + k4 * 4 + 2]; + uint8_t src_val3 = w_ptr[n * K + k4 * 4 + 3]; + b_ptr[row_base + k4] = (uint8_t(src_val3) << 6) | + (uint8_t(src_val2) << 4) | (uint8_t(src_val1) << 2) | + uint8_t(src_val0); + } + } +} + +/** + * 3-bit packing. Each weight is 3 bits. We can't pack them into a byte, so we + * pack 8 weights into 3 bytes. But we can't nicely pack the 8 weights + * continuously. Instead, we pack the upper bits of all weights into the first + * byte, then the 2 lower bits of all weights into the other 2 bytes. + */ +template <> +inline void +pack<3>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = (n * (K / 8)) * 3; + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_0ab = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_1cd = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_2ef = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_3gh = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_4ij = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_5kl = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_6mn = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_7op = w_ptr[n * K + k8 * 8 + 7]; + + // b0: 7|6|5|4|3|2|1|0 (upper bits for all values) + b_ptr[row_base + 3 * k8 + 0] = ((src_0ab & 4) >> 2) | + ((src_1cd & 4) >> 1) | ((src_2ef & 4)) | ((src_3gh & 4) << 1) | + ((src_4ij & 4) << 2) | ((src_5kl & 4) << 3) | ((src_6mn & 4) << 4) | + ((src_7op & 4) << 5); + + // b1: gh|ef|cd|ab (lower 2 bits for first 4 values) + b_ptr[row_base + 3 * k8 + 1] = (src_0ab & 3) | ((src_1cd & 3) << 2) | + ((src_2ef & 3) << 4) | ((src_3gh & 3) << 6); + + // b2: op|mn|kl|ij (lower 2 bits for last 4 values) + b_ptr[row_base + 3 * k8 + 2] = (src_4ij & 3) | ((src_5kl & 3) << 2) | + ((src_6mn & 3) << 4) | ((src_7op & 3) << 6); + } + } +} + +/** + * 4-bit packing. Each weight is four bits, so we pack 2 weights into a byte. + */ +template <> +inline void +pack<4>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = n * (K / 2); + for (int32_t k2 = 0; k2 < K / 2; k2++) { + uint8_t src_val0 = w_ptr[n * K + k2 * 2]; + uint8_t src_val1 = w_ptr[n * K + k2 * 2 + 1]; + b_ptr[row_base + k2] = (uint8_t(src_val1) << 4) | uint8_t(src_val0); + } + } +} + +/** + * 5-bit packing. Each weight is 5 bits. So we pack 8 weights into 5 bytes. We + * pack the upper bits of all weights into the first byte, then the 4 lower + * bits of all weights into the other 4 bytes. 
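+ *
+ * Layout of each 5-byte block for weights w0..w7: bit i of byte 0 holds bit 4
+ * of w_i, and byte 1+j (j = 0..3) holds the low 4 bits of w_{2j} in its low
+ * nibble and of w_{2j+1} in its high nibble.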
+ */ +template <> +inline void +pack<5>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = (n * (K / 8)) * 5; + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_0abAB = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_1cdCD = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_2efEF = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_3ghGH = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_4ijIJ = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_5klKL = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_6mnMN = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_7opOP = w_ptr[n * K + k8 * 8 + 7]; + + // b0: 7|6|5|4|3|2|1|0 (upper bits for all values) + b_ptr[row_base + 5 * k8 + 0] = ((src_0abAB & 16) >> 4) | + ((src_1cdCD & 16) >> 3) | ((src_2efEF & 16) >> 2) | + ((src_3ghGH & 16) >> 1) | ((src_4ijIJ & 16)) | + ((src_5klKL & 16) << 1) | ((src_6mnMN & 16) << 2) | + ((src_7opOP & 16) << 3); + + // b1: cdCD|abAB (lower 4 bits for first 2 values) + b_ptr[row_base + 5 * k8 + 1] = (src_0abAB & 15) | ((src_1cdCD & 15) << 4); + + // b2: ghGH|efEF (lower 4 bits for second 2 values) + b_ptr[row_base + 5 * k8 + 2] = (src_2efEF & 15) | ((src_3ghGH & 15) << 4); + + // b3: klKL|ijIJ (lower 4 bits for third 2 values) + b_ptr[row_base + 5 * k8 + 3] = (src_4ijIJ & 15) | ((src_5klKL & 15) << 4); + + // b4: opOP|mnMN (lower 4 bits for last 2 values) + b_ptr[row_base + 5 * k8 + 4] = (src_6mnMN & 15) | ((src_7opOP & 15) << 4); + } + } +} + +/** + * 6-bit packing. Each weight is 6 bits. So we pack 4 weights into 3 bytes. We + * pack the upper 2 bits of all 4 weights into the first 2 bytes, then the 4 + * lower bits of all weights into the other 4 bytes. + */ +template <> +inline void +pack<6>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = (n * (K / 4)) * 3; + for (int32_t k4 = 0; k4 < K / 4; k4++) { + uint8_t src_10abcd = w_ptr[n * K + k4 * 4 + 0]; + uint8_t src_32efgh = w_ptr[n * K + k4 * 4 + 1]; + uint8_t src_54ijkl = w_ptr[n * K + k4 * 4 + 2]; + uint8_t src_76mnop = w_ptr[n * K + k4 * 4 + 3]; + + // b0: 76|54|32|10 (upper 2 bits for all values) + b_ptr[row_base + 3 * k4 + 0] = ((src_10abcd & 48) >> 4) | + ((src_32efgh & 48) >> 2) | ((src_54ijkl & 48)) | + ((src_76mnop & 48) << 2); + + // b1: efgh|abcd (lower 4 bits for first 2 values) + b_ptr[row_base + 3 * k4 + 1] = + (src_10abcd & 15) | ((src_32efgh & 15) << 4); + + // b2: mnop|ijkl (lower 4 bits for last 2 values) + b_ptr[row_base + 3 * k4 + 2] = + (src_54ijkl & 15) | ((src_76mnop & 15) << 4); + } + } +} + +/** + * 7-bit packing. Each weight is 7 bits. So we pack 8 weights into 7 bytes. + * Each of the 7 bytes contains 1 weight, plus 1 bit from the 8th weight. So, + * this packing spreads the 8th weight across all 7 bytes. The upper bit of + * each byte is the bit from the 8th weight. 
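+ *
+ * Concretely, byte i (i = 0..6) holds w_i in bits 0-6 and bit i of w_7 in bit 7.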
+ */ +template <> +inline void +pack<7>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { + for (int32_t n = 0; n < N; n++) { + int32_t row_base = (n * (K / 8)) * 7; + for (int32_t k8 = 0; k8 < K / 8; k8++) { + uint8_t src_0 = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_7 = w_ptr[n * K + k8 * 8 + 7]; + + b_ptr[row_base + 7 * k8 + 0] = src_0 | ((src_7 & 1) << 7); + b_ptr[row_base + 7 * k8 + 1] = src_1 | ((src_7 & 2) << 6); + b_ptr[row_base + 7 * k8 + 2] = src_2 | ((src_7 & 4) << 5); + b_ptr[row_base + 7 * k8 + 3] = src_3 | ((src_7 & 8) << 4); + b_ptr[row_base + 7 * k8 + 4] = src_4 | ((src_7 & 16) << 3); + b_ptr[row_base + 7 * k8 + 5] = src_5 | ((src_7 & 32) << 2); + b_ptr[row_base + 7 * k8 + 6] = src_6 | ((src_7 & 64) << 1); + } + } +} + +} // namespace packing +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/kernels/mps/test/Makefile b/torchao/experimental/kernels/mps/test/Makefile new file mode 100644 index 0000000000..e8213818c5 --- /dev/null +++ b/torchao/experimental/kernels/mps/test/Makefile @@ -0,0 +1,7 @@ +all: test_lowbit + +test_lowbit: test_lowbit.mm + clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation + +run: test_lowbit + ./test_lowbit diff --git a/torchao/experimental/kernels/mps/test/bfloat16.h b/torchao/experimental/kernels/mps/test/bfloat16.h new file mode 100644 index 0000000000..b041c1b492 --- /dev/null +++ b/torchao/experimental/kernels/mps/test/bfloat16.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +/** + * This implementation is copied from + * executorch/runtime/core/portable_type/bfloat16.h + */ + +inline float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + std::memcpy(&res, &tmp, sizeof(tmp)); + return res; +} + +inline uint16_t bits_from_f32(float src) { + uint32_t res = 0; + std::memcpy(&res, &src, sizeof(res)); + return res >> 16; +} + +inline uint16_t round_to_nearest_even(float src) { + if (std::isnan(src)) { + return UINT16_C(0x7FC0); + } + uint32_t U32 = 0; + std::memcpy(&U32, &src, sizeof(U32)); + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); +} + +/** + * The "brain floating-point" type, compatible with c10/util/BFloat16.h from + * pytorch core. + * + * This representation uses 1 bit for the sign, 8 bits for the exponent and 7 + * bits for the mantissa. 
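+ *
+ * For example, 1.5f (bit pattern 0x3FC00000) maps to the BFloat16 bit pattern
+ * 0x3FC0 under round_to_nearest_even above.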
+ */ +struct alignas(2) BFloat16 { + uint16_t x; + + BFloat16() = default; + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {} + /* implicit */ BFloat16(float value) : x(round_to_nearest_even(value)) {} + operator float() const { + return f32_from_bits(x); + } +}; diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm new file mode 100644 index 0000000000..2342561322 --- /dev/null +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -0,0 +1,241 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +using Float16 = _Float16; + +template +const std::string_view type_string(); +template <> +const std::string_view type_string() { + return "bfloat"; +} +template <> +const std::string_view type_string() { + return "float"; +} +template <> +const std::string_view type_string() { + return "half"; +} + +inline id allocSharedBuffer(id device, unsigned length) { + id rc = [device newBufferWithLength:length + options:MTLResourceStorageModeShared]; + if (rc == nil) { + fail("Can't allocate " + std::to_string(length) + " bytes on GPU"); + } + return rc; +} + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { + +// Reference CPU implementation of lowbit quantized linear +template +void reference_linear_lowbit_quant_weights_cpu( + const T* a_ptr, + const uint8_t* w_ptr, + int64_t group_size, + const T* sz_ptr, + T* out_ptr, + int32_t M, + int32_t K, + int32_t N, + int64_t nbit) { + uint8_t zero_shift = 1 << (nbit - 1); + + for (int32_t m = 0; m < M; m++) { + for (int32_t n = 0; n < N; n++) { + const int32_t k_block = (K + group_size - 1) / group_size; + const T* A_ptr = a_ptr + m * K; + + float rc = 0.0; + int32_t k = 0; + for (int32_t kb = 0; kb < k_block; kb++) { + const float scale = float(sz_ptr[(kb * N + n) * 2 + 0]); + const float zero = + float(sz_ptr[(kb * N + n) * 2 + 1]) - scale * float(zero_shift); + for (int32_t idx = 0; idx < group_size && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uint8_t w_val = w_ptr[n * K + k]; + rc += a_val * (scale * float(w_val) + zero); + } + } + + out_ptr[m * N + n] = T(rc); + } + } +} + +template +class LowBitTester { + public: + LowBitTester(int32_t m, int32_t k, int32_t n, int32_t group_size) + : M(m), K(k), N(n), qGroupSize(group_size) {} + + void init() { + allocBuffers(MTL_DEVICE); + + T* a_ptr = reinterpret_cast([buf_A contents]); + uint8_t* w_ptr = reinterpret_cast([buf_W contents]); + T* c_ptr = reinterpret_cast([buf_C contents]); + T* s_ptr = reinterpret_cast([buf_SZ contents]); + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution<> int_distrib(0, (1 << nbit) - 1); + std::uniform_real_distribution<> real_distrib(-1.0, 1.0); + + for (int idx = 0; idx < M * K; ++idx) { + a_ptr[idx] = real_distrib(generator); + } + for (int idx = 0; idx < N * K; ++idx) { + w_ptr[idx] = int_distrib(generator); + } + int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize; + for (int idx = 0; idx < N * ceil_K_group_size; ++idx) { + s_ptr[2 * idx] = (idx + 1.0) / N; + s_ptr[2 * idx + 1] = 0; + } + for (int idx = 0; idx < M * N; ++idx) { + c_ptr[idx] = -1.0; + } + } + + void pack() { + uint8_t* w_ptr = 
reinterpret_cast([buf_W contents]); + uint8_t* b_ptr = reinterpret_cast([buf_B contents]); + LowBitQuantWeights::pack(w_ptr, b_ptr, N, K); + } + + void linear() { + LowBitQuantWeights::linear( + buf_A, buf_B, qGroupSize, buf_SZ, buf_C, M, K, N, type_string()); + } + + bool validate(float atol_lim = 5e-3, float rtol_lim = 5e-3) const { + T* a_ptr = reinterpret_cast([buf_A contents]); + uint8_t* w_ptr = reinterpret_cast([buf_W contents]); + T* c_ptr = reinterpret_cast([buf_C contents]); + T* sz_ptr = reinterpret_cast([buf_SZ contents]); + + char* e_ptr_f = new char[M * N * sizeof(T)]; // expected + T* e_ptr = reinterpret_cast(e_ptr_f); + reference_linear_lowbit_quant_weights_cpu( + a_ptr, w_ptr, qGroupSize, sz_ptr, e_ptr, M, K, N, nbit); + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + float rc = float(c_ptr[m * N + n]); + float expected = float(e_ptr[m * N + n]); + + auto atol = std::abs(rc - expected); + auto rtol = + atol / std::max(std::min(std::abs(expected), std::abs(rc)), 1e-6f); + if (rtol > rtol_lim && atol > atol_lim) { + std::cerr << "Result " << expected << " vs expected " << rc + << " (atol=" << atol << " ,rtol=" << rtol << ") at " << m + << ":" << n << std::endl; + return false; + } + } + } + return true; + } + + private: + void allocBuffers(id device) { + int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize; + const int32_t elem_size = sizeof(T); + buf_A = allocSharedBuffer(device, M * K * elem_size); + buf_W = allocSharedBuffer(device, N * K); + buf_B = allocSharedBuffer(device, N * nbit * K / 8); + buf_C = allocSharedBuffer(device, M * N * elem_size); + buf_SZ = allocSharedBuffer(device, N * ceil_K_group_size * 2 * elem_size); + } + + public: + int32_t M, K, N; // Input-output matirx dims + int32_t qGroupSize; + id buf_A; // MxK elements + id buf_W; // NxK elements + id buf_B; // NxK elements (packed) + id buf_C; // MxN elements + id buf_SZ; // (K/group_size)xNx2 elements +}; + +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao + +template +void run_test(int32_t m, int32_t k, int32_t n, int32_t group_size) { + torchao::kernels::mps::lowbit::LowBitTester tester( + m, k, n, group_size); + tester.init(); + tester.pack(); + tester.linear(); + bool success = tester.validate(); + std::cout << "Test " << type_string() << " " << nbit << "-bit " << m + << "x" << k << "x" << n << " group size: " << group_size << " " + << (success ? 
"succeeded" : "failed") << std::endl; +} + +template +void run_test_battery() { + run_test(1, 8, 1, 32); + run_test(1, 32, 1, 32); + run_test(1, 32, 1, 64); + run_test(1, 56, 1, 64); + run_test(1, 64, 1, 64); + run_test(1, 72, 1, 64); + run_test(1, 1000, 1, 64); + run_test(3, 64, 5, 64); + run_test(7, 64, 23, 64); + run_test(17, 120, 23, 128); + run_test(17, 128, 23, 128); + run_test(41, 144, 23, 128); + run_test(41, 128, 23, 128); + run_test(81, 8, 1, 256); + run_test(19, 256, 17, 256); + run_test(1, 1000, 81, 256); +} + +int main() { + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + run_test_battery(); + + return 0; +} diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/register.mm new file mode 100644 index 0000000000..a0fe15a8cd --- /dev/null +++ b/torchao/experimental/ops/mps/register.mm @@ -0,0 +1,145 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// clang-format off +#include +#include +#include +// clang-format on + +namespace torchao { +namespace kernels { +namespace mps { +namespace lowbit { +namespace aten { + +using Tensor = at::Tensor; +using namespace at::native::mps; + +// LowBit Quantized Linear on MPS Backend +template +Tensor linear_mps_kernel( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& SZ) { + auto M = A.size(0); + auto N = B.size(0); + auto K = A.size(1); + + TORCH_CHECK(A.is_mps(), __func__, "A is on ", A.device(), " but expected on mps"); + TORCH_CHECK(B.is_mps(), __func__, "B is on ", B.device(), " but expected on mps"); + TORCH_CHECK(SZ.is_mps(), __func__, "SZ is on ", SZ.device(), " but expected on mps"); + + TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat, + __func__, + " : expect A to be either 32-bit or 16-bit float tensor."); + TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous."); + TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor."); + + TORCH_CHECK(B.dtype() == kByte, __func__, " : expect B to be uint8 tensor."); + TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous."); + TORCH_CHECK(B.size(1) == K, __func__, " : expect B.size(1) == ", K); + + TORCH_CHECK(K % 8 == 0, + __func__, + ": expect K to be multiple of 8, got ", + K); + + TORCH_CHECK(group_size == 32 || group_size == 64 || group_size == 128 || group_size == 256, + __func__, + ": expect group_size to be 32, 64, 128 or 256, got ", + group_size); + + TORCH_CHECK(SZ.dim() == 3 && SZ.size(1) == N && SZ.size(2) == 2, + __func__, + ": expect SZ to be 3d tensor with sizes [:, ", + N, + ", 2]"); + + auto C = at::empty({M, N}, A.options()); + + LowBitQuantWeights::linear( + getMTLBufferStorage(A), + getMTLBufferStorage(B), + group_size, + getMTLBufferStorage(SZ), + getMTLBufferStorage(C), + M, + K, + N, + scalarToMetalTypeString(A)); + + return C; +} + +// LowBit Packing on CPU Backend +template +Tensor pack_weights_cpu_kernel(const Tensor& W) { + auto N = W.size(0); + auto K = W.size(1); + auto B = 
at::empty({N, nbit * K / 8}, W.options()); + + uint8_t* w_ptr = W.data_ptr(); + uint8_t* b_ptr = B.data_ptr(); + + LowBitQuantWeights::pack(w_ptr, b_ptr, N, K); + + return B; +} + +// Registers _C as a Python extension module. +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} + +TORCH_LIBRARY(torchao, m) { + m.def("_pack_weight_1bit(Tensor W) -> Tensor"); + m.def("_pack_weight_2bit(Tensor W) -> Tensor"); + m.def("_pack_weight_3bit(Tensor W) -> Tensor"); + m.def("_pack_weight_4bit(Tensor W) -> Tensor"); + m.def("_pack_weight_5bit(Tensor W) -> Tensor"); + m.def("_pack_weight_6bit(Tensor W) -> Tensor"); + m.def("_pack_weight_7bit(Tensor W) -> Tensor"); + m.def( + "_linear_fp_act_1bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_2bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_3bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_4bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_5bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); + m.def( + "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("_pack_weight_1bit", &pack_weights_cpu_kernel<1>); + m.impl("_pack_weight_2bit", &pack_weights_cpu_kernel<2>); + m.impl("_pack_weight_3bit", &pack_weights_cpu_kernel<3>); + m.impl("_pack_weight_4bit", &pack_weights_cpu_kernel<4>); + m.impl("_pack_weight_5bit", &pack_weights_cpu_kernel<5>); + m.impl("_pack_weight_6bit", &pack_weights_cpu_kernel<6>); + m.impl("_pack_weight_7bit", &pack_weights_cpu_kernel<7>); +} + +TORCH_LIBRARY_IMPL(torchao, MPS, m) { + m.impl("_linear_fp_act_1bit_weight", &linear_mps_kernel<1>); + m.impl("_linear_fp_act_2bit_weight", &linear_mps_kernel<2>); + m.impl("_linear_fp_act_3bit_weight", &linear_mps_kernel<3>); + m.impl("_linear_fp_act_4bit_weight", &linear_mps_kernel<4>); + m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>); + m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>); + m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>); +} + +} // namespace aten +} // namespace lowbit +} // namespace mps +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/ops/mps/setup.py b/torchao/experimental/ops/mps/setup.py new file mode 100644 index 0000000000..e9c206cdb9 --- /dev/null +++ b/torchao/experimental/ops/mps/setup.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from setuptools import setup +from torch.utils.cpp_extension import CppExtension, BuildExtension + +setup( + name="torchao_mps_ops", + version="1.0", + ext_modules=[ + CppExtension( + name="torchao_mps_ops", + sources=["register.mm"], + include_dirs=[os.getenv("TORCHAO_ROOT")], + extra_compile_args=["-DATEN=1"], + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py new file mode 100644 index 0000000000..679fceeac4 --- /dev/null +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torchao_mps_ops +import unittest + + +def parameterized(test_cases): + def decorator(func): + def wrapper(self): + for case in test_cases: + with self.subTest(case=case): + func(self, *case) + + return wrapper + + return decorator + + +class TestLowBitQuantWeightsLinear(unittest.TestCase): + cases = [ + (nbit, *param) + for nbit in range(1, 8) + for param in [ + (1, 32, 32, 32), + (128, 48, 64, 32), + (17, 1024, 512, 256), + ] + ] + + def _init_tensors(self, group_size, M, K, N, nbit, device="mps"): + max_abs = 1 << (nbit - 1) + ceil_K_group_size = (K + group_size - 1) // group_size + A = 2 * torch.rand(M, K, dtype=torch.float32, device=device) - 1 + W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device) + S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01 + Z = torch.randint( + -max_abs, + max_abs, + (ceil_K_group_size, N), + dtype=torch.float32, + device=device, + ) + SZ = torch.stack((S, Z), dim=2) + return A, W, SZ + + def _reference_linear_lowbit_quant_weights(self, A, W, group_size, SZ, nbit): + # A is (M, K) + # W is (N, K) + # SZ is (K // group_size, N, 2) + N = W.shape[0] + K = W.shape[1] + max_abs = 1 << (nbit - 1) + W = W.to(torch.float32) - max_abs + scales = ( + SZ[:, :, 0].t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + ) + zeros = SZ[:, :, 1].t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + W = scales * W + zeros + return torch.mm(A, W.t()) + + @parameterized(cases) + def test_linear(self, nbit, M=1, K=32, N=32, group_size=32): + print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}") + A, W, SZ = self._init_tensors(group_size, M, K, N, nbit=nbit) + packing_op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") + linear_op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + B = packing_op(W.cpu()).to("mps") + result = linear_op(A, B, group_size, SZ).cpu() + expected = self._reference_linear_lowbit_quant_weights( + A.cpu(), W.cpu(), group_size, SZ.cpu(), nbit=nbit + ) + torch.testing.assert_close(result, expected, rtol=0.001, atol=0.001) + + +if __name__ == "__main__": + unittest.main()
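For quick reference, below is a minimal end-to-end usage sketch of the ops registered in register.mm, mirroring the Python test above. It assumes the extension has been built from torchao/experimental/ops/mps/setup.py with the TORCHAO_ROOT environment variable set and that an MPS device is available; the shapes, nbit and group_size values are illustrative only.

```python
# Minimal end-to-end sketch of the registered ops (mirrors test_lowbit.py above).
# Assumes torchao_mps_ops was built via torchao/experimental/ops/mps/setup.py
# with TORCHAO_ROOT set, and that an MPS device is available. Shapes, nbit and
# group_size below are arbitrary illustrative choices.
import torch
import torchao_mps_ops  # noqa: F401  registers torch.ops.torchao.*

nbit, group_size = 4, 32
M, K, N = 8, 128, 64

A = torch.randn(M, K, dtype=torch.float32, device="mps")    # activations on MPS
W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8)  # unpacked weights on CPU

# Per-group scales and zero points, shaped (K // group_size, N, 2); the kernels
# dequantize each weight w as scale * w + (zero - scale * (1 << (nbit - 1))).
n_groups = (K + group_size - 1) // group_size
S = torch.rand(n_groups, N, dtype=torch.float32) + 0.01
Z = torch.zeros(n_groups, N, dtype=torch.float32)
SZ = torch.stack((S, Z), dim=2).to("mps")

B = torch.ops.torchao._pack_weight_4bit(W)  # CPU packing -> (N, nbit * K / 8) uint8
C = torch.ops.torchao._linear_fp_act_4bit_weight(A, B.to("mps"), group_size, SZ)
print(C.shape)  # (M, N) on mps
```

The same pattern applies to the other bitwidths through the corresponding _pack_weight_{n}bit and _linear_fp_act_{n}bit_weight ops.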