Skip to content

Commit

Permalink
Add 2bit packing routines
Browse files Browse the repository at this point in the history
Summary:
Adds 2-bit packing/unpacking routines, together with tests/benchmarks.

This adds 24 new 2 bit kernels (3 tile sizes x 8 variants).

Reviewed By: digantdesai

Differential Revision: D62133659
  • Loading branch information
metascroy authored and facebook-github-bot committed Sep 3, 2024
1 parent e15e509 commit 73618d8
Show file tree
Hide file tree
Showing 5 changed files with 461 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,144 @@

#include <arm_neon.h>
#include <benchmark/benchmark.h>
#include <iostream>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <cassert>

namespace {

// Benchmark utility to compare variants of uint2 packing
void pack_uint2_values(
uint8_t* packed,
uint8_t* unpacked,
int packed_size,
int unpacked_size,
int variant) {
constexpr int nbit = 2;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);

uint8x8_t unpacked0_8x8;
uint8x8_t unpacked1_8x8;
uint8x8_t unpacked2_8x8;
uint8x8_t unpacked3_8x8;

uint8x16_t unpacked0_8x16;
uint8x16_t unpacked1_8x16;
uint8x16_t unpacked2_8x16;
uint8x16_t unpacked3_8x16;

switch (variant) {
case 4:
for (int i = 0; i < unpacked_size; i += 4) {
torchao::bitpacking::internal::pack_4_uint2_values(
packed + ((i * nbit) / bitsPerByte), unpacked + i);
}
break;
case 32:
for (int i = 0; i < unpacked_size; i += 32) {
torchao::bitpacking::internal::vec_load_32_uint8_values(
unpacked0_8x8,
unpacked1_8x8,
unpacked2_8x8,
unpacked3_8x8,
unpacked + i);
torchao::bitpacking::internal::vec_pack_32_uint2_values(
packed + ((i * nbit) / bitsPerByte),
unpacked0_8x8,
unpacked1_8x8,
unpacked2_8x8,
unpacked3_8x8);
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked0_8x16,
unpacked1_8x16,
unpacked2_8x16,
unpacked3_8x16,
unpacked + i);
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed + ((i * nbit) / bitsPerByte),
unpacked0_8x16,
unpacked1_8x16,
unpacked2_8x16,
unpacked3_8x16);
}
break;
}
}

// Benchmark utility to compare variants of uint2 packing
void unpack_uint2_values(
uint8_t* unpacked,
uint8_t* packed,
int unpacked_size,
int packed_size,
int variant) {
constexpr int nbit = 2;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);

uint8x8_t unpacked0_8x8;
uint8x8_t unpacked1_8x8;
uint8x8_t unpacked2_8x8;
uint8x8_t unpacked3_8x8;

uint8x16_t unpacked0_8x16;
uint8x16_t unpacked1_8x16;
uint8x16_t unpacked2_8x16;
uint8x16_t unpacked3_8x16;

switch (variant) {
case 4:
for (int i = 0; i < unpacked_size; i += 4) {
torchao::bitpacking::internal::unpack_4_uint2_values(
unpacked + i, packed + ((i * nbit) / bitsPerByte));
}
break;
case 32:
for (int i = 0; i < unpacked_size; i += 32) {
torchao::bitpacking::internal::vec_unpack_32_uint2_values(
unpacked0_8x8,
unpacked1_8x8,
unpacked2_8x8,
unpacked3_8x8,
packed + ((i * nbit) / bitsPerByte));
torchao::bitpacking::internal::vec_store_32_uint8_values(
unpacked + i,
unpacked0_8x8,
unpacked1_8x8,
unpacked2_8x8,
unpacked3_8x8);
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
unpacked0_8x16,
unpacked1_8x16,
unpacked2_8x16,
unpacked3_8x16,
packed + ((i * nbit) / bitsPerByte));
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i,
unpacked0_8x16,
unpacked1_8x16,
unpacked2_8x16,
unpacked3_8x16);
}
break;
}
}

// Benchmark utility to compare variants of uint3 packing
void pack_uint3_values(
uint8_t* packed,
Expand Down Expand Up @@ -220,6 +348,44 @@ void unpack_uint4_values(

} // namespace

static void benchmark_pack_uint2_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
int nbit = 2;

assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);

for (auto _ : state) {
pack_uint2_values(
packed.data(), unpacked.data(), packed_size, unpacked_size, variant);
}
}

static void benchmark_unpack_uint2_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
int nbit = 2;

assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = torchao::get_random_lowbit_vector(packed_size, 8);
auto unpacked = std::vector<uint8_t>(unpacked_size, 0);

for (auto _ : state) {
unpack_uint2_values(
unpacked.data(),
packed.data(),
unpacked.size(),
packed.size(),
variant);
}
}

static void benchmark_pack_uint3_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
Expand Down Expand Up @@ -296,6 +462,8 @@ static void benchmark_unpack_uint4_values(benchmark::State& state) {
}
}

BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_pack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,20 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
false>) \
->ArgsProduct(BENCHMARK_PARAMS)

BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
Expand Down
73 changes: 67 additions & 6 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once
#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <cassert>
Expand All @@ -15,6 +16,30 @@ namespace torchao {
namespace bitpacking {

namespace internal {
TORCHAO_ALWAYS_INLINE inline void vec_store_32_uint8_values(
uint8_t* dest,
const uint8x8_t& vec0,
const uint8x8_t& vec1,
const uint8x8_t& vec2,
const uint8x8_t& vec3) {
vst1_u8(dest, vec0);
vst1_u8(dest + 8, vec1);
vst1_u8(dest + 16, vec2);
vst1_u8(dest + 24, vec3);
}

TORCHAO_ALWAYS_INLINE inline void vec_load_32_uint8_values(
uint8x8_t& vec0,
uint8x8_t& vec1,
uint8x8_t& vec2,
uint8x8_t& vec3,
const uint8_t* src) {
vec0 = vld1_u8(src);
vec1 = vld1_u8(src + 8);
vec2 = vld1_u8(src + 16);
vec3 = vld1_u8(src + 24);
}

TORCHAO_ALWAYS_INLINE inline void vec_store_64_uint8_values(
uint8_t* dest,
const uint8x16_t& vec0,
Expand Down Expand Up @@ -49,7 +74,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

// Shift unpacked values to nonnegative range
Expand All @@ -58,6 +83,13 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_pack_32_uint2_values(
packed,
vget_low_u8(shifted0),
vget_high_u8(shifted0),
vget_low_u8(shifted1),
vget_high_u8(shifted1));
case 3:
uint8_t buffer[32];
vst1q_u8(buffer, shifted0);
Expand Down Expand Up @@ -89,13 +121,22 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

uint8x16_t shifted0;
uint8x16_t shifted1;

switch (nbit) {
case 2:
uint8x8_t shifted0_low;
uint8x8_t shifted0_high;
uint8x8_t shifted1_low;
uint8x8_t shifted1_high;
torchao::bitpacking::internal::vec_unpack_32_uint2_values(
shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed);
shifted0 = vcombine_u8(shifted0_low, shifted0_high);
shifted1 = vcombine_u8(shifted1_low, shifted1_high);
case 3:
uint8_t buffer[32];
torchao::bitpacking::internal::unpack_8_uint3_values(buffer, packed);
Expand Down Expand Up @@ -133,7 +174,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

// Shift unpacked values to nonnegative range
Expand All @@ -144,6 +185,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed, shifted0, shifted1, shifted2, shifted3);
break;
case 3:
torchao::bitpacking::internal::vec_pack_64_uint3_values(
packed, shifted0, shifted1, shifted2, shifted3);
Expand All @@ -170,7 +215,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

uint8x16_t shifted0;
Expand All @@ -179,6 +224,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
uint8x16_t shifted3;

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
shifted0, shifted1, shifted2, shifted3, packed);
break;
case 3:
torchao::bitpacking::internal::vec_unpack_64_uint3_values(
shifted0, shifted1, shifted2, shifted3, packed);
Expand Down Expand Up @@ -216,7 +265,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

// Shift unpacked values to nonnegative range
Expand All @@ -231,6 +280,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed, shifted0, shifted1, shifted2, shifted3);
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed + 16, shifted4, shifted5, shifted6, shifted7);
break;
case 3:
torchao::bitpacking::internal::vec_pack_128_uint3_values(
packed,
Expand Down Expand Up @@ -273,7 +328,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

uint8x16_t shifted0;
Expand All @@ -286,6 +341,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
uint8x16_t shifted7;

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
shifted0, shifted1, shifted2, shifted3, packed);
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
shifted4, shifted5, shifted6, shifted7, packed + 16);
break;
case 3:
torchao::bitpacking::internal::vec_unpack_128_uint3_values(
shifted0,
Expand Down
Loading

0 comments on commit 73618d8

Please sign in to comment.