Skip to content

Commit

Permalink
Experimental 6-bit quantization for Llama in torchchat
Browse files Browse the repository at this point in the history
Differential Revision: D64437228

Pull Request resolved: #1094
  • Loading branch information
c4lcut3c authored Oct 16, 2024
1 parent ce4822b commit 893cafe
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 0 deletions.
177 changes: 177 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,183 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(b3210, 4));
}

TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
    uint8_t* packed,
    const uint8_t* unpacked) {
  // Packs four 6-bit values (one per input byte) into three output bytes.
  //
  // Writing the inputs bit by bit, most significant bit first, as
  //   unpacked[0] = abcdef, unpacked[1] = ghijkl,
  //   unpacked[2] = mnopqr, unpacked[3] = 123456
  // the output layout is:
  //   packed[0] = 56abcdef
  //   packed[1] = 34ghijkl
  //   packed[2] = 12mnopqr
  //
  // Reads 4 bytes from `unpacked` and writes 4 * 6 / 8 = 3 bytes to `packed`.
  // The first three inputs are copied without masking, so they are expected
  // to already fit in 6 bits.

  // Values 0..2 each occupy the low 6 bits of their own output byte.
  for (int i = 0; i < 3; ++i) {
    packed[i] = unpacked[i];
  }

  // Value 3 is cut into three 2-bit slices (low, middle, high); each slice is
  // stored in the two otherwise unused top bits of one output byte.
  const uint8_t last = unpacked[3];
  packed[0] |= static_cast<uint8_t>((last & 0x03u) << 6);
  packed[1] |= static_cast<uint8_t>(((last >> 2) & 0x03u) << 6);
  packed[2] |= static_cast<uint8_t>(((last >> 4) & 0x03u) << 6);
}

TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
    uint8_t* unpacked,
    const uint8_t* packed) {
  // Inverse of pack_4_uint6_values_v2.
  //
  // Reads 3 bytes (24 bits) from `packed` and writes 4 bytes to `unpacked`,
  // each holding one 6-bit value.
  constexpr uint8_t kLowSixBits = 0x3F;

  // Values 0..2 are the low 6 bits of the corresponding packed byte.
  for (int i = 0; i < 3; ++i) {
    unpacked[i] = packed[i] & kLowSixBits;
  }

  // Value 3 is rebuilt from the top 2 bits of each packed byte:
  // packed[0] carries its bits 0-1, packed[1] bits 2-3, packed[2] bits 4-5.
  const uint8_t low_slice = packed[0] >> 6;
  const uint8_t mid_slice = packed[1] >> 6;
  const uint8_t high_slice = packed[2] >> 6;
  unpacked[3] =
      static_cast<uint8_t>(low_slice | (mid_slice << 2) | (high_slice << 4));
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
    uint8_t* packed,
    const uint8x16_t& unpacked0,
    const uint8x16_t& unpacked1) {
  // Vectorized form of pack_4_uint6_values_v2, processing 8 groups of four
  // values at once. The four "values" of the scalar routine map to 8-lane
  // halves of the inputs:
  //
  //   unpacked[0] -> low half of unpacked0
  //   unpacked[1] -> high half of unpacked0
  //   unpacked[2] -> low half of unpacked1
  //   unpacked[3] -> high half of unpacked1
  //
  // Input is 32 bytes (two 16-byte vectors).
  // Output is 6 * 32 = 192 bits = 24 bytes, written as three 8-byte stores.
  const uint8x8_t value0 = vget_low_u8(unpacked0);
  const uint8x8_t value1 = vget_high_u8(unpacked0);
  const uint8x8_t value2 = vget_low_u8(unpacked1);
  const uint8x8_t value3 = vget_high_u8(unpacked1);

  // Each 2-bit slice of value3 is shifted into bit positions 6-7 and then
  // isolated with this mask, which is equivalent to masking the slice first
  // and shifting afterwards as the scalar code does.
  const uint8x8_t top_two_bits = vdup_n_u8(0xC0);

  // packed[0] = unpacked[0] | ((unpacked[3] & 0b00'0011u) << 6)
  const uint8x8_t slice_low = vand_u8(vshl_n_u8(value3, 6), top_two_bits);
  vst1_u8(packed, vorr_u8(value0, slice_low));

  // packed[1] = unpacked[1] | ((unpacked[3] & 0b00'1100u) << 4)
  const uint8x8_t slice_mid = vand_u8(vshl_n_u8(value3, 4), top_two_bits);
  vst1_u8(packed + 8, vorr_u8(value1, slice_mid));

  // packed[2] = unpacked[2] | ((unpacked[3] & 0b11'0000u) << 2)
  const uint8x8_t slice_high = vand_u8(vshl_n_u8(value3, 2), top_two_bits);
  vst1_u8(packed + 16, vorr_u8(value2, slice_high));
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
    uint8x16_t& unpacked0,
    uint8x16_t& unpacked1,
    const uint8_t* packed) {
  // Inverse of vec_pack_32_uint6_values_v2, and the vectorized form of
  // unpack_4_uint6_values_v2: each 8-byte chunk of `packed` plays the role of
  // one packed byte of the scalar routine.
  //
  // Input is 24 bytes (three 8-byte loads).
  // Output is 32 bytes (two 16-byte vectors).
  const uint8x8_t chunk0 = vld1_u8(packed);
  const uint8x8_t chunk1 = vld1_u8(packed + 8);
  const uint8x8_t chunk2 = vld1_u8(packed + 16);

  // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
  //               ((packed[1] & 0b1100'0000u) >> 4) |
  //               ((packed[2] & 0b1100'0000u) >> 2);
  // Here the shift is applied first and the two relocated bits are masked
  // afterwards; for chunk0 the shift by 6 already leaves only those two bits.
  const uint8x8_t bits_0_1 = vshr_n_u8(chunk0, 6);
  const uint8x8_t bits_2_3 = vand_u8(vshr_n_u8(chunk1, 4), vdup_n_u8(0x0C));
  const uint8x8_t bits_4_5 = vand_u8(vshr_n_u8(chunk2, 2), vdup_n_u8(0x30));
  const uint8x8_t value3 = vorr_u8(vorr_u8(bits_0_1, bits_2_3), bits_4_5);

  // unpacked[i] = packed[i] & 0b11'1111u;
  const uint8x8_t low_six_bits = vdup_n_u8(0x3F);
  const uint8x8_t value0 = vand_u8(chunk0, low_six_bits);
  const uint8x8_t value1 = vand_u8(chunk1, low_six_bits);
  const uint8x8_t value2 = vand_u8(chunk2, low_six_bits);

  unpacked0 = vcombine_u8(value0, value1);
  unpacked1 = vcombine_u8(value2, value3);
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
    uint8_t* packed,
    const uint8x16_t& unpacked0,
    const uint8x16_t& unpacked1,
    const uint8x16_t& unpacked2,
    const uint8x16_t& unpacked3) {
  // This function is a vectorized version of pack_4_uint6_values_v2 that
  // handles 16 groups of four values at once: unpacked0..unpacked3 play the
  // role of unpacked[0]..unpacked[3] in the scalar routine.
  // Before each code section, there is a comment indicating the
  // code in pack_4_uint6_values_v2 that is being vectorized.
  //
  // Input is 64 bytes (four 16-byte vectors).
  // Output is 6 * 64 = 384 bits = 48 bytes, written as three 16-byte stores,
  // so `packed` must have room for 48 bytes.
  uint8x16_t r;

  // packed[0] = unpacked[0]
  // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
  r = vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'0011u)), 6);
  vst1q_u8(packed, vorrq_u8(unpacked0, r));

  // packed[1] = unpacked[1]
  // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
  r = vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'1100u)), 4);
  vst1q_u8(packed + 16, vorrq_u8(unpacked1, r));

  // packed[2] = unpacked[2]
  // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
  r = vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b11'0000u)), 2);
  vst1q_u8(packed + 32, vorrq_u8(unpacked2, r));
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
    uint8x16_t& unpacked0,
    uint8x16_t& unpacked1,
    uint8x16_t& unpacked2,
    uint8x16_t& unpacked3,
    const uint8_t* packed) {
  // Inverse of vec_pack_64_uint6_values_v2, and the vectorized form of
  // unpack_4_uint6_values_v2: each 16-byte chunk of `packed` plays the role
  // of one packed byte of the scalar routine.
  //
  // Input is 48 bytes (three 16-byte loads).
  // Output is 64 bytes (four 16-byte vectors).
  const uint8x16_t chunk0 = vld1q_u8(packed);
  const uint8x16_t chunk1 = vld1q_u8(packed + 16);
  const uint8x16_t chunk2 = vld1q_u8(packed + 32);

  // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
  //               ((packed[1] & 0b1100'0000u) >> 4) |
  //               ((packed[2] & 0b1100'0000u) >> 2);
  // Here the shift is applied first and the two relocated bits are masked
  // afterwards; for chunk0 the shift by 6 already leaves only those two bits.
  const uint8x16_t bits_0_1 = vshrq_n_u8(chunk0, 6);
  const uint8x16_t bits_2_3 = vandq_u8(vshrq_n_u8(chunk1, 4), vdupq_n_u8(0x0C));
  const uint8x16_t bits_4_5 = vandq_u8(vshrq_n_u8(chunk2, 2), vdupq_n_u8(0x30));

  // unpacked[i] = packed[i] & 0b11'1111u;
  const uint8x16_t low_six_bits = vdupq_n_u8(0x3F);
  unpacked0 = vandq_u8(chunk0, low_six_bits);
  unpacked1 = vandq_u8(chunk1, low_six_bits);
  unpacked2 = vandq_u8(chunk2, low_six_bits);
  unpacked3 = vorrq_u8(vorrq_u8(bits_0_1, bits_2_3), bits_4_5);
}

} // namespace internal
} // namespace bitpacking
} // namespace torchao
Expand Down
73 changes: 73 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,23 @@ TEST(test_bitpacking_4_uint6_values, PackUnpackAreSame) {
}
}

TEST(test_bitpacking_4_uint6_values_v2, PackUnpackAreSame) {
  // Round-trip check for the scalar v2 routines: packing four random 6-bit
  // values into three bytes and unpacking them must give back the input.
  namespace internal = torchao::bitpacking::internal;

  constexpr int kUnpackedBytes = 4;
  constexpr int kPackedBytes = 3;

  auto input = torchao::get_random_lowbit_vector(kUnpackedBytes, 6);
  std::vector<uint8_t> packed(kPackedBytes, 0);
  std::vector<uint8_t> unpacked(kUnpackedBytes, 0);

  internal::pack_4_uint6_values_v2(packed.data(), input.data());
  internal::unpack_4_uint6_values_v2(unpacked.data(), packed.data());

  for (int idx = 0; idx < kUnpackedBytes; ++idx) {
    EXPECT_EQ(input[idx], unpacked[idx]);
  }
}


TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) {
int unpacked_bytes = 32;
int packed_bytes = 24;
Expand All @@ -529,6 +546,31 @@ TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) {
}
}

TEST(test_bitpacking_32_uint6_values_v2, PackUnpackAreSame) {
  // Round-trip check for the 32-value NEON v2 routines: 32 random 6-bit
  // values are packed into 24 bytes and unpacked again; every lane must
  // match the original input.
  namespace internal = torchao::bitpacking::internal;

  constexpr int kUnpackedBytes = 32;
  constexpr int kPackedBytes = 24;
  constexpr int kLanesPerVector = 16;

  auto input = torchao::get_random_lowbit_vector(kUnpackedBytes, 6);
  std::vector<uint8_t> packed(kPackedBytes, 0);

  // Split the 32 input bytes across two 16-lane vectors.
  uint8x16_t input0 = vld1q_u8(input.data());
  uint8x16_t input1 = vld1q_u8(input.data() + kLanesPerVector);

  uint8x16_t unpacked0;
  uint8x16_t unpacked1;

  internal::vec_pack_32_uint6_values_v2(packed.data(), input0, input1);
  internal::vec_unpack_32_uint6_values_v2(unpacked0, unpacked1, packed.data());

  for (int lane = 0; lane < kLanesPerVector; ++lane) {
    EXPECT_EQ(input0[lane], unpacked0[lane]);
    EXPECT_EQ(input1[lane], unpacked1[lane]);
  }
}

TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) {
int unpacked_bytes = 64;
int packed_bytes = 48;
Expand Down Expand Up @@ -560,6 +602,37 @@ TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) {
}
}

TEST(test_bitpacking_64_uint6_values_v2, PackUnpackAreSame) {
  // Round-trip check for the 64-value NEON v2 routines: 64 random 6-bit
  // values are packed into 48 bytes and unpacked again; every lane must
  // match the original input.
  namespace internal = torchao::bitpacking::internal;

  constexpr int kUnpackedBytes = 64;
  constexpr int kPackedBytes = 48;
  constexpr int kLanesPerVector = 16;

  auto input = torchao::get_random_lowbit_vector(kUnpackedBytes, 6);
  std::vector<uint8_t> packed(kPackedBytes, 0);

  uint8x16_t input0, input1, input2, input3;
  uint8x16_t unpacked0, unpacked1, unpacked2, unpacked3;

  internal::vec_load_64_uint8_values(
      input0, input1, input2, input3, input.data());
  internal::vec_pack_64_uint6_values_v2(
      packed.data(), input0, input1, input2, input3);
  internal::vec_unpack_64_uint6_values_v2(
      unpacked0, unpacked1, unpacked2, unpacked3, packed.data());

  for (int lane = 0; lane < kLanesPerVector; ++lane) {
    EXPECT_EQ(input0[lane], unpacked0[lane]);
    EXPECT_EQ(input1[lane], unpacked1[lane]);
    EXPECT_EQ(input2[lane], unpacked2[lane]);
    EXPECT_EQ(input3[lane], unpacked3[lane]);
  }
}

// Universal bitpacking tests
template <int nbit>
void test_bitpacking_32_lowbit_values() {
Expand Down

0 comments on commit 893cafe

Please sign in to comment.