Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental 6-bit quantization for Llama in torchchat #1094

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,183 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(b3210, 4));
}

TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
uint8_t* packed,
const uint8_t* unpacked) {
// Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
// this function packs them as:
// packed[0]: 56 | abcdef
// packed[1]: 34 | ghijkl
// packed[2]: 12 | mnopqr
//
// Input is 4 bytes
// Output is 6 * 4 bits/8 = 3 bytes
packed[0] = unpacked[0];
packed[1] = unpacked[1];
packed[2] = unpacked[2];
// Last value is packed in the upper 2 bits of the three bytes
packed[0] |= ((unpacked[3] & 0b00'0011u) << 6);
packed[1] |= ((unpacked[3] & 0b00'1100u) << 4);
packed[2] |= ((unpacked[3] & 0b11'0000u) << 2);
}

TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
uint8_t* unpacked,
const uint8_t* packed) {
// Unpacks data packed by pack_4_uint6_values_v2
//
// Input is 24 bits = 3 bytes
// Output is 4 bytes
unpacked[0] = packed[0] & 0b111111u;
unpacked[1] = packed[1] & 0b111111u;
unpacked[2] = packed[2] & 0b111111u;
// Last value is packed in the upper 2 bits of the three bytes
unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
((packed[1] & 0b1100'0000u) >> 4) |
((packed[2] & 0b1100'0000u) >> 2);
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
uint8_t* packed,
const uint8x16_t& unpacked0,
const uint8x16_t& unpacked1) {
// This function is a vectorized version of pack_4_uint6_values_v2.
// To understand the following code, please see pack_4_uint6_values_v2 first and
// consider the following mapping for the unpacked parameter of that function:
//
// unpacked[0] -> vget_low_u8(unpacked0)
// unpacked[1] -> vget_high_u8(unpacked0)
// unpacked[2] -> vget_low_u8(unpacked1)
// unpacked[3] -> vget_high_u8(unpacked1)
//
// Before each code section, there is a comment indicating the
// code in pack_4_uint6_values_v2 that is being vectorized.
//
// Input is 32 bytes.
// Output is 6*32= 192 bits = 24 bytes.
uint8x8_t r;

// packed[0] = unpacked[0]
// packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
r = vget_low_u8(unpacked0);
r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'0011u)), 6));
vst1_u8(packed, r);

// packed[1] = unpacked[1]
// packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
r = vget_high_u8(unpacked0);
r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'1100u)), 4));
vst1_u8(packed + 8, r);

// packed[2] = unpacked[2]
// packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
r = vget_low_u8(unpacked1);
r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b11'0000u)), 2));
vst1_u8(packed + 16, r);
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
uint8x16_t& unpacked0,
uint8x16_t& unpacked1,
const uint8_t* packed) {
// Unpacks data packed by vec_pack_32_uint6_values_v2.
//
// This function vectorizes unpack_4_uint6_values_v2.
// To understand it, please see unpack_4_uint6_values_v2 first.
// Before each code section, there is a comment indicating the
// code in unpack_4_uint6_values_v2 that is being vectorized.
//
// Input is 24 bytes.
// Output is 32 bytes.
uint8x8_t packed0 = vld1_u8(packed);
uint8x8_t packed1 = vld1_u8(packed + 8);
uint8x8_t packed2 = vld1_u8(packed + 16);

// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
// ((packed[1] & 0b1100'0000u) >> 4) |
// ((packed[2] & 0b1100'0000u) >> 2);
const uint8x8_t high = vdup_n_u8(0b1100'0000u);
uint8x8_t unpacked3;
unpacked3 = vorr_u8(vshr_n_u8(vand_u8(packed0, high), 6),
vshr_n_u8(vand_u8(packed1, high), 4));
unpacked3 = vorr_u8(unpacked3,
vshr_n_u8(vand_u8(packed2, high), 2));

// unpacked[i] = packed[i] & 0b11'1111u;
const uint8x8_t mask = vdup_n_u8(0b11'1111u);
unpacked0 = vcombine_u8(vand_u8(packed0, mask), vand_u8(packed1, mask));
unpacked1 = vcombine_u8(vand_u8(packed2, mask), unpacked3);
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
uint8_t* packed,
const uint8x16_t& unpacked0,
const uint8x16_t& unpacked1,
const uint8x16_t& unpacked2,
const uint8x16_t& unpacked3) {
// This function is a vectorized version of pack_4_uint6_values_v2.
// To understand the following code, please see pack_4_uint6_values_v2 first.
// Before each code section, there is a comment indicating the
// code in pack_4_uint6_values_v2 that is being vectorized.
//
// Input is 48 bytes.
// Output is 64 bytes.
uint8x16_t r;

// packed[0] = unpacked[0]
// packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
r = unpacked0;
r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'0011u)), 6));
vst1q_u8(packed, r);

// packed[1] = unpacked[1]
// packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
r = unpacked1;
r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'1100u)), 4));
vst1q_u8(packed + 16, r);

// packed[2] = unpacked[2]
// packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
r = unpacked2;
r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b11'0000u)), 2));
vst1q_u8(packed + 32, r);
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
uint8x16_t& unpacked0,
uint8x16_t& unpacked1,
uint8x16_t& unpacked2,
uint8x16_t& unpacked3,
const uint8_t* packed) {
// Unpacks data packed by vec_pack_64_uint6_values_v2.
//
// This function vectorizes unpack_4_uint6_values_v2.
// To understand it, please see unpack_4_uint6_values_v2 first.
// Before each code section, there is a comment indicating the
// code in unpack_4_uint6_values that is being vectorized

// Input is 48 bytes.
// Output is 64 bytes.
unpacked0 = vld1q_u8(packed);
unpacked1 = vld1q_u8(packed + 16);
unpacked2 = vld1q_u8(packed + 32);

// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
// ((packed[1] & 0b1100'0000u) >> 4) |
// ((packed[2] & 0b1100'0000u) >> 2);
const uint8x16_t high = vdupq_n_u8(0b1100'0000u);
unpacked3 = vorrq_u8(vshrq_n_u8(vandq_u8(unpacked0, high), 6),
vshrq_n_u8(vandq_u8(unpacked1, high), 4));
unpacked3 = vorrq_u8(unpacked3,
vshrq_n_u8(vandq_u8(unpacked2, high), 2));

// unpacked[i] = packed[i] & 0b11'1111u;
const uint8x16_t mask = vdupq_n_u8(0b11'1111u);
unpacked0 = vandq_u8(unpacked0, mask);
unpacked1 = vandq_u8(unpacked1, mask);
unpacked2 = vandq_u8(unpacked2, mask);
}

} // namespace internal
} // namespace bitpacking
} // namespace torchao
Expand Down
73 changes: 73 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,23 @@ TEST(test_bitpacking_4_uint6_values, PackUnpackAreSame) {
}
}

TEST(test_bitpacking_4_uint6_values_v2, PackUnpackAreSame) {
int unpacked_bytes = 4;
int packed_bytes = 3;
auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6);
std::vector<uint8_t> packed(packed_bytes, 0);
std::vector<uint8_t> unpacked(unpacked_bytes, 0);

torchao::bitpacking::internal::pack_4_uint6_values_v2(
packed.data(), input.data());
torchao::bitpacking::internal::unpack_4_uint6_values_v2(
unpacked.data(), packed.data());
for (int i = 0; i < unpacked_bytes; ++i) {
EXPECT_EQ(input[i], unpacked[i]);
}
}


TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) {
int unpacked_bytes = 32;
int packed_bytes = 24;
Expand All @@ -529,6 +546,31 @@ TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) {
}
}

TEST(test_bitpacking_32_uint6_values_v2, PackUnpackAreSame) {
int unpacked_bytes = 32;
int packed_bytes = 24;
auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6);
std::vector<uint8_t> packed(packed_bytes, 0);

uint8x16_t input0;
uint8x16_t input1;

uint8x16_t unpacked0;
uint8x16_t unpacked1;

input0 = vld1q_u8(input.data());
input1 = vld1q_u8(input.data() + 16);
torchao::bitpacking::internal::vec_pack_32_uint6_values_v2(
packed.data(), input0, input1);
torchao::bitpacking::internal::vec_unpack_32_uint6_values_v2(
unpacked0, unpacked1, packed.data());

for (int i = 0; i < 16; ++i) {
EXPECT_EQ(input0[i], unpacked0[i]);
EXPECT_EQ(input1[i], unpacked1[i]);
}
}

TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) {
int unpacked_bytes = 64;
int packed_bytes = 48;
Expand Down Expand Up @@ -560,6 +602,37 @@ TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) {
}
}

TEST(test_bitpacking_64_uint6_values_v2, PackUnpackAreSame) {
int unpacked_bytes = 64;
int packed_bytes = 48;
auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6);
std::vector<uint8_t> packed(packed_bytes, 0);

uint8x16_t input0;
uint8x16_t input1;
uint8x16_t input2;
uint8x16_t input3;

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;

torchao::bitpacking::internal::vec_load_64_uint8_values(
input0, input1, input2, input3, input.data());
torchao::bitpacking::internal::vec_pack_64_uint6_values_v2(
packed.data(), input0, input1, input2, input3);
torchao::bitpacking::internal::vec_unpack_64_uint6_values_v2(
unpacked0, unpacked1, unpacked2, unpacked3, packed.data());

for (int i = 0; i < 16; ++i) {
EXPECT_EQ(input0[i], unpacked0[i]);
EXPECT_EQ(input1[i], unpacked1[i]);
EXPECT_EQ(input2[i], unpacked2[i]);
EXPECT_EQ(input3[i], unpacked3[i]);
}
}

// Universal bitpacking tests
template <int nbit>
void test_bitpacking_32_lowbit_values() {
Expand Down