Skip to content

Commit

Permalink
Introduce lowbit quantized linear MPS kernels (#954)
Browse files Browse the repository at this point in the history
Summary:

The following is the directory structure of the submitted code under torchao

```
experimental/
├── kernels/
│   └── mps/
│       ├── metal/
│       │   └── (metal shaders)
│       ├── src/
│       │   └── (tensor agnostic mps kernel implementations)
│       └── test/
│       │   └── (directly test mps kernel implementations)
└── ops/
    └── mps/
        ├── register.mm
        ├── setup.py
        └── test/
            └── (test torch custom ops)
```

Differential Revision: D63342895
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Sep 26, 2024
1 parent b149edb commit e783565
Show file tree
Hide file tree
Showing 20 changed files with 2,037 additions and 0 deletions.
97 changes: 97 additions & 0 deletions torchao/experimental/kernels/mps/metal/divbit.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include <metal_stdlib>
using namespace metal;

// dispatchThreads:MTLSizeMake(N, M, 1)

template<unsigned nbit, typename T, unsigned groupSize>
kernel void divbit_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B;

constexpr uint8_t zero_shift = 1 << (nbit - 1);
constexpr uint8_t inv_nbit = 8 / nbit;
constexpr uint8_t minimask = (1 << nbit) - 1;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const T scale = scalesAndZeros[(kb * N + n) * 2 + 0];
const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift);
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
const auto a_val = float(A_ptr[k]);
uint8_t b_val = B_ptr[(n * K + k) / inv_nbit];
uint8_t shift = nbit * (k % inv_nbit);
uint8_t mask = minimask << shift;
b_val = (b_val & mask) >> shift;
rc += a_val * float(scale * T(b_val) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
template \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void divbit_mm<NBIT, DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_DIVBIT_MM(1, float, 32);
INSTANTIATE_DIVBIT_MM(1, half, 32);
INSTANTIATE_DIVBIT_MM(1, float, 64);
INSTANTIATE_DIVBIT_MM(1, half, 64);
INSTANTIATE_DIVBIT_MM(1, float, 128);
INSTANTIATE_DIVBIT_MM(1, half, 128);
INSTANTIATE_DIVBIT_MM(1, float, 256);
INSTANTIATE_DIVBIT_MM(1, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(1, bfloat, 32);
INSTANTIATE_DIVBIT_MM(1, bfloat, 64);
INSTANTIATE_DIVBIT_MM(1, bfloat, 128);
INSTANTIATE_DIVBIT_MM(1, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(2, float, 32);
INSTANTIATE_DIVBIT_MM(2, half, 32);
INSTANTIATE_DIVBIT_MM(2, float, 64);
INSTANTIATE_DIVBIT_MM(2, half, 64);
INSTANTIATE_DIVBIT_MM(2, float, 128);
INSTANTIATE_DIVBIT_MM(2, half, 128);
INSTANTIATE_DIVBIT_MM(2, float, 256);
INSTANTIATE_DIVBIT_MM(2, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(2, bfloat, 32);
INSTANTIATE_DIVBIT_MM(2, bfloat, 64);
INSTANTIATE_DIVBIT_MM(2, bfloat, 128);
INSTANTIATE_DIVBIT_MM(2, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(4, float, 32);
INSTANTIATE_DIVBIT_MM(4, half, 32);
INSTANTIATE_DIVBIT_MM(4, float, 64);
INSTANTIATE_DIVBIT_MM(4, half, 64);
INSTANTIATE_DIVBIT_MM(4, float, 128);
INSTANTIATE_DIVBIT_MM(4, half, 128);
INSTANTIATE_DIVBIT_MM(4, float, 256);
INSTANTIATE_DIVBIT_MM(4, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(4, bfloat, 32);
INSTANTIATE_DIVBIT_MM(4, bfloat, 64);
INSTANTIATE_DIVBIT_MM(4, bfloat, 128);
INSTANTIATE_DIVBIT_MM(4, bfloat, 256);
#endif
63 changes: 63 additions & 0 deletions torchao/experimental/kernels/mps/metal/int1mm.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <metal_stdlib>
using namespace metal;

// dispatchThreads:MTLSizeMake(N, M, 1)

template<typename T, unsigned groupSize>
kernel void int1pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale;
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
const auto a_val = float(A_ptr[k]);
uint8_t b_val = B_ptr[(n * K + k) / 8];
uint8_t shift = (k % 8);
uint8_t mask = 1 << shift;
b_val = (b_val & mask) >> shift;
rc += a_val * (scale * float(b_val) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \
template \
[[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int1pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT1MM(float, 32);
INSTANTIATE_INT1MM(half, 32);
INSTANTIATE_INT1MM(float, 64);
INSTANTIATE_INT1MM(half, 64);
INSTANTIATE_INT1MM(float, 128);
INSTANTIATE_INT1MM(half, 128);
INSTANTIATE_INT1MM(float, 256);
INSTANTIATE_INT1MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT1MM(bfloat, 32);
INSTANTIATE_INT1MM(bfloat, 64);
INSTANTIATE_INT1MM(bfloat, 128);
INSTANTIATE_INT1MM(bfloat, 256);
#endif
63 changes: 63 additions & 0 deletions torchao/experimental/kernels/mps/metal/int2mm.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <metal_stdlib>
using namespace metal;

// dispatchThreads:MTLSizeMake(N, M, 1)

template<typename T, unsigned groupSize>
kernel void int2pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = scalesAndZeros[(kb * N + n) * 2 + 0];
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(2);
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
const auto a_val = float(A_ptr[k]);
uint8_t b_val = B_ptr[(n * K + k) / 4];
uint8_t shift = 2 * (k % 4);
uint8_t mask = 3 << shift;
b_val = (b_val & mask) >> shift;
rc += a_val * (scale * float(b_val) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT2MM(DTYPE, GSIZE) \
template \
[[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int2pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT2MM(float, 32);
INSTANTIATE_INT2MM(half, 32);
INSTANTIATE_INT2MM(float, 64);
INSTANTIATE_INT2MM(half, 64);
INSTANTIATE_INT2MM(float, 128);
INSTANTIATE_INT2MM(half, 128);
INSTANTIATE_INT2MM(float, 256);
INSTANTIATE_INT2MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT2MM(bfloat, 32);
INSTANTIATE_INT2MM(bfloat, 64);
INSTANTIATE_INT2MM(bfloat, 128);
INSTANTIATE_INT2MM(bfloat, 256);
#endif
88 changes: 88 additions & 0 deletions torchao/experimental/kernels/mps/metal/int3mm.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <metal_stdlib>
using namespace metal;

// dispatchThreads:MTLSizeMake(N, M, 1)

template<typename T, unsigned groupSize>
kernel void int3pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B + n * 3 * K / 8;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
const auto a_val2 = float(A_ptr[k + 2]);
const auto a_val3 = float(A_ptr[k + 3]);
const auto a_val4 = float(A_ptr[k + 4]);
const auto a_val5 = float(A_ptr[k + 5]);
const auto a_val6 = float(A_ptr[k + 6]);
const auto a_val7 = float(A_ptr[k + 7]);

uchar b0 = B_ptr[3 * (k / 8) + 0];
uchar b1 = B_ptr[3 * (k / 8) + 1];
uchar b2 = B_ptr[3 * (k / 8) + 2];

uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);

uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);

rc += a_val0 * (scale * float(w_val0) + zero);
rc += a_val1 * (scale * float(w_val1) + zero);
rc += a_val2 * (scale * float(w_val2) + zero);
rc += a_val3 * (scale * float(w_val3) + zero);
rc += a_val4 * (scale * float(w_val4) + zero);
rc += a_val5 * (scale * float(w_val5) + zero);
rc += a_val6 * (scale * float(w_val6) + zero);
rc += a_val7 * (scale * float(w_val7) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
template \
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int3pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
INSTANTIATE_INT3MM(half, 32);
INSTANTIATE_INT3MM(float, 64);
INSTANTIATE_INT3MM(half, 64);
INSTANTIATE_INT3MM(float, 128);
INSTANTIATE_INT3MM(half, 128);
INSTANTIATE_INT3MM(float, 256);
INSTANTIATE_INT3MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT3MM(bfloat, 32);
INSTANTIATE_INT3MM(bfloat, 64);
INSTANTIATE_INT3MM(bfloat, 128);
INSTANTIATE_INT3MM(bfloat, 256);
#endif
Loading

0 comments on commit e783565

Please sign in to comment.