Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 2bit packing routines #797

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,144 @@

#include <arm_neon.h>
#include <benchmark/benchmark.h>
#include <iostream>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <cassert>

namespace {

// Benchmark utility to compare variants of uint2 packing.
//
// Packs the values in `unpacked` (one value per byte) into `packed`, four
// values per packed byte. `variant` is the number of unpacked values handled
// per call and selects the routine being measured:
//    4 -> pack_4_uint2_values
//   32 -> vec_load_32_uint8_values + vec_pack_32_uint2_values
//   64 -> vec_load_64_uint8_values + vec_pack_64_uint2_values
void pack_uint2_values(
    uint8_t* packed,
    uint8_t* unpacked,
    int packed_size,
    int unpacked_size,
    int variant) {
  constexpr int nbit = 2;
  constexpr int bitsPerByte = 8;
  assert(unpacked_size * nbit / bitsPerByte == packed_size);
  // Every loop below walks the *unpacked* buffer in steps of `variant`, so
  // that is the size that has to be a whole number of steps. Checking
  // packed_size instead rejects the 128-value / variant-64 configuration,
  // where packed_size is only 32.
  assert(variant > 0 && unpacked_size % variant == 0);

  switch (variant) {
    case 4:
      for (int i = 0; i < unpacked_size; i += 4) {
        torchao::bitpacking::internal::pack_4_uint2_values(
            packed + ((i * nbit) / bitsPerByte), unpacked + i);
      }
      break;
    case 32:
      for (int i = 0; i < unpacked_size; i += 32) {
        uint8x8_t unpacked0_8x8;
        uint8x8_t unpacked1_8x8;
        uint8x8_t unpacked2_8x8;
        uint8x8_t unpacked3_8x8;
        torchao::bitpacking::internal::vec_load_32_uint8_values(
            unpacked0_8x8,
            unpacked1_8x8,
            unpacked2_8x8,
            unpacked3_8x8,
            unpacked + i);
        torchao::bitpacking::internal::vec_pack_32_uint2_values(
            packed + ((i * nbit) / bitsPerByte),
            unpacked0_8x8,
            unpacked1_8x8,
            unpacked2_8x8,
            unpacked3_8x8);
      }
      break;
    case 64:
      for (int i = 0; i < unpacked_size; i += 64) {
        uint8x16_t unpacked0_8x16;
        uint8x16_t unpacked1_8x16;
        uint8x16_t unpacked2_8x16;
        uint8x16_t unpacked3_8x16;
        torchao::bitpacking::internal::vec_load_64_uint8_values(
            unpacked0_8x16,
            unpacked1_8x16,
            unpacked2_8x16,
            unpacked3_8x16,
            unpacked + i);
        torchao::bitpacking::internal::vec_pack_64_uint2_values(
            packed + ((i * nbit) / bitsPerByte),
            unpacked0_8x16,
            unpacked1_8x16,
            unpacked2_8x16,
            unpacked3_8x16);
      }
      break;
    default:
      // An unknown variant would pack nothing and make the timing
      // meaningless, so flag it in builds with assertions enabled.
      assert(false && "unsupported uint2 pack variant");
      break;
  }
}

// Benchmark utility to compare variants of uint2 unpacking.
//
// Expands `packed` (four values per byte) into `unpacked` (one value per
// byte). `variant` is the number of unpacked values produced per call and
// selects the routine being measured:
//    4 -> unpack_4_uint2_values
//   32 -> vec_unpack_32_uint2_values + vec_store_32_uint8_values
//   64 -> vec_unpack_64_uint2_values + vec_store_64_uint8_values
void unpack_uint2_values(
    uint8_t* unpacked,
    uint8_t* packed,
    int unpacked_size,
    int packed_size,
    int variant) {
  constexpr int nbit = 2;
  constexpr int bitsPerByte = 8;
  assert(unpacked_size * nbit / bitsPerByte == packed_size);
  // Every loop below walks the *unpacked* buffer in steps of `variant`, so
  // that is the size that has to be a whole number of steps. Checking
  // packed_size instead rejects the 128-value / variant-64 configuration,
  // where packed_size is only 32.
  assert(variant > 0 && unpacked_size % variant == 0);

  switch (variant) {
    case 4:
      for (int i = 0; i < unpacked_size; i += 4) {
        torchao::bitpacking::internal::unpack_4_uint2_values(
            unpacked + i, packed + ((i * nbit) / bitsPerByte));
      }
      break;
    case 32:
      for (int i = 0; i < unpacked_size; i += 32) {
        uint8x8_t unpacked0_8x8;
        uint8x8_t unpacked1_8x8;
        uint8x8_t unpacked2_8x8;
        uint8x8_t unpacked3_8x8;
        torchao::bitpacking::internal::vec_unpack_32_uint2_values(
            unpacked0_8x8,
            unpacked1_8x8,
            unpacked2_8x8,
            unpacked3_8x8,
            packed + ((i * nbit) / bitsPerByte));
        torchao::bitpacking::internal::vec_store_32_uint8_values(
            unpacked + i,
            unpacked0_8x8,
            unpacked1_8x8,
            unpacked2_8x8,
            unpacked3_8x8);
      }
      break;
    case 64:
      for (int i = 0; i < unpacked_size; i += 64) {
        uint8x16_t unpacked0_8x16;
        uint8x16_t unpacked1_8x16;
        uint8x16_t unpacked2_8x16;
        uint8x16_t unpacked3_8x16;
        torchao::bitpacking::internal::vec_unpack_64_uint2_values(
            unpacked0_8x16,
            unpacked1_8x16,
            unpacked2_8x16,
            unpacked3_8x16,
            packed + ((i * nbit) / bitsPerByte));
        torchao::bitpacking::internal::vec_store_64_uint8_values(
            unpacked + i,
            unpacked0_8x16,
            unpacked1_8x16,
            unpacked2_8x16,
            unpacked3_8x16);
      }
      break;
    default:
      // An unknown variant would unpack nothing and make the timing
      // meaningless, so flag it in builds with assertions enabled.
      assert(false && "unsupported uint2 unpack variant");
      break;
  }
}

// Benchmark utility to compare variants of uint3 packing
void pack_uint3_values(
uint8_t* packed,
Expand Down Expand Up @@ -220,6 +348,44 @@ void unpack_uint4_values(

} // namespace

// Benchmarks pack_uint2_values.
//   state.range(0): number of unpacked values (must be a multiple of 8)
//   state.range(1): packing variant forwarded to pack_uint2_values
static void benchmark_pack_uint2_values(benchmark::State& state) {
  int unpacked_size = state.range(0);
  int variant = state.range(1);
  int nbit = 2;

  assert(unpacked_size % 8 == 0);
  int packed_size = (unpacked_size / 8) * nbit;

  // The packer reads `unpacked_size` input bytes and writes `packed_size`
  // output bytes, so the buffers must be sized accordingly. Sizing the input
  // with `packed_size` makes the packer read past the end of the vector.
  auto packed = std::vector<uint8_t>(packed_size, 0);
  // NOTE(review): assumes the second argument of get_random_lowbit_vector is
  // the bit width of the generated values - confirm against test_utils.h.
  auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

  for (auto _ : state) {
    pack_uint2_values(
        packed.data(), unpacked.data(), packed_size, unpacked_size, variant);
  }
}

// Benchmarks unpack_uint2_values.
//   state.range(0): number of unpacked values (must be a multiple of 8)
//   state.range(1): unpacking variant forwarded to unpack_uint2_values
static void benchmark_unpack_uint2_values(benchmark::State& state) {
  const int num_unpacked = static_cast<int>(state.range(0));
  const int which_variant = static_cast<int>(state.range(1));
  const int nbit = 2;

  assert(num_unpacked % 8 == 0);
  const int num_packed = (num_unpacked / 8) * nbit;

  // Random bytes as packed input; zero-filled destination for the output.
  std::vector<uint8_t> packed =
      torchao::get_random_lowbit_vector(num_packed, 8);
  std::vector<uint8_t> unpacked(num_unpacked, 0);

  uint8_t* const dst = unpacked.data();
  uint8_t* const src = packed.data();
  const int dst_len = static_cast<int>(unpacked.size());
  const int src_len = static_cast<int>(packed.size());

  for (auto _ : state) {
    unpack_uint2_values(dst, src, dst_len, src_len, which_variant);
  }
}

static void benchmark_pack_uint3_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
Expand Down Expand Up @@ -296,6 +462,8 @@ static void benchmark_unpack_uint4_values(benchmark::State& state) {
}
}

// Register the bit-packing benchmarks. Each ArgsProduct call supplies two
// argument lists; the uint2 benchmarks read the first as unpacked_size
// (state.range(0)) and the second as variant (state.range(1)).
BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_pack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,20 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
false>) \
->ArgsProduct(BENCHMARK_PARAMS)

// Invoke each kernel benchmark macro (1x1x32, 1x4x16, 1x8x16) once per
// argument value.
// NOTE(review): the argument is presumably the weight bit width - confirm
// against the macro definitions.
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
Expand Down
73 changes: 67 additions & 6 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once
#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <cassert>
Expand All @@ -15,6 +16,30 @@ namespace torchao {
namespace bitpacking {

namespace internal {
// Writes four 8-lane uint8 vectors to 32 consecutive bytes starting at
// `dest`: vec0 -> dest[0..7], vec1 -> dest[8..15], vec2 -> dest[16..23],
// vec3 -> dest[24..31].
TORCHAO_ALWAYS_INLINE inline void vec_store_32_uint8_values(
    uint8_t* dest,
    const uint8x8_t& vec0,
    const uint8x8_t& vec1,
    const uint8x8_t& vec2,
    const uint8x8_t& vec3) {
  constexpr int kLanes = 8;
  uint8_t* cursor = dest;
  vst1_u8(cursor, vec0);
  cursor += kLanes;
  vst1_u8(cursor, vec1);
  cursor += kLanes;
  vst1_u8(cursor, vec2);
  cursor += kLanes;
  vst1_u8(cursor, vec3);
}

// Reads 32 consecutive bytes starting at `src` into four 8-lane uint8
// vectors: src[0..7] -> vec0, src[8..15] -> vec1, src[16..23] -> vec2,
// src[24..31] -> vec3.
TORCHAO_ALWAYS_INLINE inline void vec_load_32_uint8_values(
    uint8x8_t& vec0,
    uint8x8_t& vec1,
    uint8x8_t& vec2,
    uint8x8_t& vec3,
    const uint8_t* src) {
  constexpr int kLanes = 8;
  const uint8_t* cursor = src;
  vec0 = vld1_u8(cursor);
  cursor += kLanes;
  vec1 = vld1_u8(cursor);
  cursor += kLanes;
  vec2 = vld1_u8(cursor);
  cursor += kLanes;
  vec3 = vld1_u8(cursor);
}

TORCHAO_ALWAYS_INLINE inline void vec_store_64_uint8_values(
uint8_t* dest,
const uint8x16_t& vec0,
Expand Down Expand Up @@ -49,7 +74,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

// Shift unpacked values to nonnegative range
Expand All @@ -58,6 +83,13 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_pack_32_uint2_values(
packed,
vget_low_u8(shifted0),
vget_high_u8(shifted0),
vget_low_u8(shifted1),
vget_high_u8(shifted1));
case 3:
uint8_t buffer[32];
vst1q_u8(buffer, shifted0);
Expand Down Expand Up @@ -89,13 +121,22 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

uint8x16_t shifted0;
uint8x16_t shifted1;

switch (nbit) {
case 2: {
  // The 32-value uint2 unpacker produces four 8-lane halves; recombine them
  // into the two 16-lane registers used by the rest of this function.
  uint8x8_t shifted0_low;
  uint8x8_t shifted0_high;
  uint8x8_t shifted1_low;
  uint8x8_t shifted1_high;
  torchao::bitpacking::internal::vec_unpack_32_uint2_values(
      shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed);
  shifted0 = vcombine_u8(shifted0_low, shifted0_high);
  shifted1 = vcombine_u8(shifted1_low, shifted1_high);
  // Without this break, control falls through into the 3-bit path below,
  // which would decode `packed` again as uint3 data.
  break;
}
case 3:
uint8_t buffer[32];
torchao::bitpacking::internal::unpack_8_uint3_values(buffer, packed);
Expand Down Expand Up @@ -133,7 +174,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

// Shift unpacked values to nonnegative range
Expand All @@ -144,6 +185,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed, shifted0, shifted1, shifted2, shifted3);
break;
case 3:
torchao::bitpacking::internal::vec_pack_64_uint3_values(
packed, shifted0, shifted1, shifted2, shifted3);
Expand All @@ -170,7 +215,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

uint8x16_t shifted0;
Expand All @@ -179,6 +224,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
uint8x16_t shifted3;

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
shifted0, shifted1, shifted2, shifted3, packed);
break;
case 3:
torchao::bitpacking::internal::vec_unpack_64_uint3_values(
shifted0, shifted1, shifted2, shifted3, packed);
Expand Down Expand Up @@ -216,7 +265,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

// Shift unpacked values to nonnegative range
Expand All @@ -231,6 +280,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed, shifted0, shifted1, shifted2, shifted3);
torchao::bitpacking::internal::vec_pack_64_uint2_values(
packed + 16, shifted4, shifted5, shifted6, shifted7);
break;
case 3:
torchao::bitpacking::internal::vec_pack_128_uint3_values(
packed,
Expand Down Expand Up @@ -273,7 +328,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
static_assert(nbit >= 2);

// Currently supported values
static_assert(nbit >= 3);
static_assert(nbit >= 2);
static_assert(nbit <= 4);

uint8x16_t shifted0;
Expand All @@ -286,6 +341,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
uint8x16_t shifted7;

switch (nbit) {
case 2:
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
shifted0, shifted1, shifted2, shifted3, packed);
torchao::bitpacking::internal::vec_unpack_64_uint2_values(
shifted4, shifted5, shifted6, shifted7, packed + 16);
break;
case 3:
torchao::bitpacking::internal::vec_unpack_128_uint3_values(
shifted0,
Expand Down
Loading
Loading