Skip to content

Commit

Permalink
Introduce lowbit quantized linear MPS kernels (#954)
Browse files Browse the repository at this point in the history
Summary:

The following is the directory structure of the submitted code under torchao

```
experimental/
├── kernels/
│   └── mps/
│       ├── metal/
│       │   └── (metal shaders)
│       ├── src/
│       │   └── (tensor agnostic mps kernel implementations)
│       └── test/
│           └── (directly test mps kernel implementations)
└── ops/
    └── mps/
        ├── register.mm
        ├── setup.py
        └── test/
            └── (test torch custom ops)
```

Differential Revision: D63342895
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Oct 10, 2024
1 parent 107e378 commit 6c16c82
Show file tree
Hide file tree
Showing 15 changed files with 1,641 additions and 0 deletions.
106 changes: 106 additions & 0 deletions torchao/experimental/kernels/mps/metal/divbit.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include <metal_stdlib>
using namespace metal;

/**
 * LowBit Quantized Linear for bitwidths that are divisors of 8. Hence the name.
 *
 * Each thread computes one element of the M x N output: the dot product of
 * row m of A with the dequantized column n of B.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned nbit, unsigned groupSize>
kernel void divbit_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along the K dimension (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B;

  // 2^(nbit-1): recenters the unsigned packed value into a signed range.
  constexpr uint8_t zero_shift = 1 << (nbit - 1);
  constexpr uint8_t values_per_byte = 8 / nbit;
  constexpr uint8_t minimask = (1 << nbit) - 1; // mask for the low nbit bits

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Dequantization parameters for group kb of column n. Computed in float
    // (rather than T) for consistency with the int3/int5/int6 kernels and to
    // avoid half/bfloat rounding in the per-element products below.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(zero_shift);
    for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
      const auto a_val = float(A_ptr[k]);
      // Extract the k-th nbit value of column n from its packed byte.
      uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
      uint8_t shift = nbit * (k % values_per_byte);
      uint8_t mask = minimask << shift;
      b_val = (b_val & mask) >> shift;
      rc += a_val * (scale * float(b_val) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

// Instantiates divbit_mm for a given bitwidth, dtype and group size, and
// exposes it to the host under the name "int<NBIT>pack_mm_<GSIZE>_<DTYPE>"
// (e.g. "int4pack_mm_32_float") via [[host_name]]; the host side looks the
// kernel up by this exact string, so it must not change.
#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
template \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void divbit_mm<DTYPE, NBIT, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

// 1-bit variants, for every supported group size.
INSTANTIATE_DIVBIT_MM(1, float, 32);
INSTANTIATE_DIVBIT_MM(1, half, 32);
INSTANTIATE_DIVBIT_MM(1, float, 64);
INSTANTIATE_DIVBIT_MM(1, half, 64);
INSTANTIATE_DIVBIT_MM(1, float, 128);
INSTANTIATE_DIVBIT_MM(1, half, 128);
INSTANTIATE_DIVBIT_MM(1, float, 256);
INSTANTIATE_DIVBIT_MM(1, half, 256);
// bfloat is only available from MSL 3.1 (__METAL_VERSION__ >= 310).
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(1, bfloat, 32);
INSTANTIATE_DIVBIT_MM(1, bfloat, 64);
INSTANTIATE_DIVBIT_MM(1, bfloat, 128);
INSTANTIATE_DIVBIT_MM(1, bfloat, 256);
#endif

// 2-bit variants.
INSTANTIATE_DIVBIT_MM(2, float, 32);
INSTANTIATE_DIVBIT_MM(2, half, 32);
INSTANTIATE_DIVBIT_MM(2, float, 64);
INSTANTIATE_DIVBIT_MM(2, half, 64);
INSTANTIATE_DIVBIT_MM(2, float, 128);
INSTANTIATE_DIVBIT_MM(2, half, 128);
INSTANTIATE_DIVBIT_MM(2, float, 256);
INSTANTIATE_DIVBIT_MM(2, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(2, bfloat, 32);
INSTANTIATE_DIVBIT_MM(2, bfloat, 64);
INSTANTIATE_DIVBIT_MM(2, bfloat, 128);
INSTANTIATE_DIVBIT_MM(2, bfloat, 256);
#endif

// 4-bit variants.
INSTANTIATE_DIVBIT_MM(4, float, 32);
INSTANTIATE_DIVBIT_MM(4, half, 32);
INSTANTIATE_DIVBIT_MM(4, float, 64);
INSTANTIATE_DIVBIT_MM(4, half, 64);
INSTANTIATE_DIVBIT_MM(4, float, 128);
INSTANTIATE_DIVBIT_MM(4, half, 128);
INSTANTIATE_DIVBIT_MM(4, float, 256);
INSTANTIATE_DIVBIT_MM(4, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(4, bfloat, 32);
INSTANTIATE_DIVBIT_MM(4, bfloat, 64);
INSTANTIATE_DIVBIT_MM(4, bfloat, 128);
INSTANTIATE_DIVBIT_MM(4, bfloat, 256);
#endif
97 changes: 97 additions & 0 deletions torchao/experimental/kernels/mps/metal/int3mm.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include <metal_stdlib>
using namespace metal;

/**
 * 3-Bit Quantized Linear.
 *
 * Each thread computes one element of the M x N output: the dot product of
 * row m of A with the dequantized column n of B.
 *
 * Packing layout, per 8 consecutive weights stored in 3 bytes (b0, b1, b2):
 *   - bit j of b0 is the high bit (bit 2) of weight j, for j = 0..7
 *   - b1 holds the low 2 bits of weights 0..3 (2 bits each, weight 0 lowest)
 *   - b2 holds the low 2 bits of weights 4..7 (2 bits each, weight 4 lowest)
 *
 * NOTE(review): the inner loop advances by 8 and reads A_ptr[k]..A_ptr[k+7]
 * unconditionally, so K (and groupSize) are assumed to be multiples of 8 —
 * confirm with the host-side packing code.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned groupSize>
kernel void int3pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along the K dimension (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B + n * 3 * K / 8; // start of column n's packed bytes

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Dequantization parameters for group kb of column n.
    // 4 = 2^(3-1) recenters the unsigned 3-bit value into a signed range.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
    // Process 8 weights (3 packed bytes) per iteration.
    for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      uchar b0 = B_ptr[3 * (k / 8) + 0];
      uchar b1 = B_ptr[3 * (k / 8) + 1];
      uchar b2 = B_ptr[3 * (k / 8) + 2];

      // Reassemble each 3-bit weight: high bit from b0, low 2 bits from b1.
      uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
      uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
      uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
      uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);

      // Weights 4..7: high bit from b0, low 2 bits from b2.
      uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
      uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
      uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
      uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

// Instantiates int3pack_mm for a given dtype and group size, exposed to the
// host under the name "int3pack_mm_<GSIZE>_<DTYPE>" via [[host_name]]; the
// host side looks the kernel up by this exact string.
#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
template \
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int3pack_mm<DTYPE, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
INSTANTIATE_INT3MM(half, 32);
INSTANTIATE_INT3MM(float, 64);
INSTANTIATE_INT3MM(half, 64);
INSTANTIATE_INT3MM(float, 128);
INSTANTIATE_INT3MM(half, 128);
INSTANTIATE_INT3MM(float, 256);
INSTANTIATE_INT3MM(half, 256);
// bfloat is only available from MSL 3.1 (__METAL_VERSION__ >= 310).
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT3MM(bfloat, 32);
INSTANTIATE_INT3MM(bfloat, 64);
INSTANTIATE_INT3MM(bfloat, 128);
INSTANTIATE_INT3MM(bfloat, 256);
#endif
99 changes: 99 additions & 0 deletions torchao/experimental/kernels/mps/metal/int5mm.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include <metal_stdlib>
using namespace metal;

/**
 * 5-Bit Quantized Linear.
 *
 * Each thread computes one element of the M x N output: the dot product of
 * row m of A with the dequantized column n of B.
 *
 * Packing layout, per 8 consecutive weights stored in 5 bytes (b0..b4):
 *   - bit j of b0 is the high bit (bit 4) of weight j, for j = 0..7
 *   - the low 4 bits of each weight are stored in nibbles: b1 holds weights
 *     0 (low nibble) and 1 (high nibble), b2 holds 2/3, b3 holds 4/5,
 *     b4 holds 6/7
 *
 * NOTE(review): the inner loop advances by 8 and reads A_ptr[k]..A_ptr[k+7]
 * unconditionally, so K (and groupSize) are assumed to be multiples of 8 —
 * confirm with the host-side packing code.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned groupSize>
kernel void int5pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along the K dimension (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B + n * 5 * K / 8; // start of column n's packed bytes

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Dequantization parameters for group kb of column n.
    // 16 = 2^(5-1) recenters the unsigned 5-bit value into a signed range.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
    // Process 8 weights (5 packed bytes) per iteration.
    for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      uchar b0 = B_ptr[5 * (k / 8) + 0];
      uchar b1 = B_ptr[5 * (k / 8) + 1];
      uchar b2 = B_ptr[5 * (k / 8) + 2];
      uchar b3 = B_ptr[5 * (k / 8) + 3];
      uchar b4 = B_ptr[5 * (k / 8) + 4];

      // Reassemble each 5-bit weight: high bit from b0, low nibble from b1..b4.
      uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
      uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
      uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
      uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);

      uchar w_val4 = ((b0 & 16)) | (b3 & 15);
      uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
      uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
      uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

// Instantiates int5pack_mm for a given dtype and group size, exposed to the
// host under the name "int5pack_mm_<GSIZE>_<DTYPE>" via [[host_name]]; the
// host side looks the kernel up by this exact string.
#define INSTANTIATE_INT5MM(DTYPE, GSIZE) \
template \
[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int5pack_mm<DTYPE, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT5MM(float, 32);
INSTANTIATE_INT5MM(half, 32);
INSTANTIATE_INT5MM(float, 64);
INSTANTIATE_INT5MM(half, 64);
INSTANTIATE_INT5MM(float, 128);
INSTANTIATE_INT5MM(half, 128);
INSTANTIATE_INT5MM(float, 256);
INSTANTIATE_INT5MM(half, 256);
// bfloat is only available from MSL 3.1 (__METAL_VERSION__ >= 310).
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT5MM(bfloat, 32);
INSTANTIATE_INT5MM(bfloat, 64);
INSTANTIATE_INT5MM(bfloat, 128);
INSTANTIATE_INT5MM(bfloat, 256);
#endif
84 changes: 84 additions & 0 deletions torchao/experimental/kernels/mps/metal/int6mm.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include <metal_stdlib>
using namespace metal;

/**
 * 6-Bit Quantized Linear.
 *
 * Each thread computes one element of the M x N output: the dot product of
 * row m of A with the dequantized column n of B.
 *
 * Packing layout, per 4 consecutive weights stored in 3 bytes (b0, b1, b2):
 *   - bits 2j..2j+1 of b0 are the high 2 bits (bits 4..5) of weight j, j = 0..3
 *   - b1 holds the low 4 bits of weights 0 (low nibble) and 1 (high nibble)
 *   - b2 holds the low 4 bits of weights 2 (low nibble) and 3 (high nibble)
 *
 * NOTE(review): the inner loop advances by 4 and reads A_ptr[k]..A_ptr[k+3]
 * unconditionally, so K (and groupSize) are assumed to be multiples of 4 —
 * confirm with the host-side packing code.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned groupSize>
kernel void int6pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along the K dimension (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B + n * 3 * K / 4; // 6*K/8 packed bytes per column

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Dequantization parameters for group kb of column n.
    // 32 = 2^(6-1) recenters the unsigned 6-bit value into a signed range.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32);
    // Process 4 weights (3 packed bytes) per iteration.
    for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);

      uchar b0 = B_ptr[3 * (k / 4) + 0];
      uchar b1 = B_ptr[3 * (k / 4) + 1];
      uchar b2 = B_ptr[3 * (k / 4) + 2];

      // Reassemble each 6-bit weight: high 2 bits from b0, low nibble from b1/b2.
      uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15);
      uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4);
      uchar w_val2 = ((b0 & 48)) | (b2 & 15);
      uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

// Instantiates int6pack_mm for a given dtype and group size, exposed to the
// host under the name "int6pack_mm_<GSIZE>_<DTYPE>" via [[host_name]]; the
// host side looks the kernel up by this exact string.
#define INSTANTIATE_INT6MM(DTYPE, GSIZE) \
template \
[[host_name("int6pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int6pack_mm<DTYPE, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT6MM(float, 32);
INSTANTIATE_INT6MM(half, 32);
INSTANTIATE_INT6MM(float, 64);
INSTANTIATE_INT6MM(half, 64);
INSTANTIATE_INT6MM(float, 128);
INSTANTIATE_INT6MM(half, 128);
INSTANTIATE_INT6MM(float, 256);
INSTANTIATE_INT6MM(half, 256);
// bfloat is only available from MSL 3.1 (__METAL_VERSION__ >= 310).
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT6MM(bfloat, 32);
INSTANTIATE_INT6MM(bfloat, 64);
INSTANTIATE_INT6MM(bfloat, 128);
INSTANTIATE_INT6MM(bfloat, 256);
#endif
Loading

0 comments on commit 6c16c82

Please sign in to comment.