From 893cafe62b99ca4b5f34fe8df26ebc057f7714d2 Mon Sep 17 00:00:00 2001 From: c4lcut3c <97532828+c4lcut3c@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:29:42 -0700 Subject: [PATCH] Experimental 6-bit quantization for Llama in torchchat Differential Revision: D64437228 Pull Request resolved: https://github.com/pytorch/ao/pull/1094 --- .../kernels/cpu/aarch64/bitpacking/uint6.h | 177 ++++++++++++++++++ .../cpu/aarch64/tests/test_bitpacking.cpp | 73 ++++++++ 2 files changed, 250 insertions(+) diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h index fd7535a022..c65974bbbb 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h @@ -234,6 +234,183 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values( unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(b3210, 4)); } +TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2( + uint8_t* packed, + const uint8_t* unpacked) { + // Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456 + // this function packs them as: + // packed[0]: 56 | abcdef + // packed[1]: 34 | ghijkl + // packed[2]: 12 | mnopqr + // + // Input is 4 bytes + // Output is 6 * 4 bits/8 = 3 bytes + packed[0] = unpacked[0]; + packed[1] = unpacked[1]; + packed[2] = unpacked[2]; + // Last value is packed in the upper 2 bits of the three bytes + packed[0] |= ((unpacked[3] & 0b00'0011u) << 6); + packed[1] |= ((unpacked[3] & 0b00'1100u) << 4); + packed[2] |= ((unpacked[3] & 0b11'0000u) << 2); +} + +TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpacks data packed by pack_4_uint6_values_v2 + // + // Input is 24 bits = 3 bytes + // Output is 4 bytes + unpacked[0] = packed[0] & 0b111111u; + unpacked[1] = packed[1] & 0b111111u; + unpacked[2] = packed[2] & 0b111111u; + // Last value is packed in the upper 2 bits of the three bytes + unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) | + ((packed[1] & 0b1100'0000u) >> 4) | + ((packed[2] & 0b1100'0000u) >> 2); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1) { + // This function is a vectorized version of pack_4_uint6_values_v2. + // To understand the following code, please see pack_4_uint6_values_v2 first and + // consider the following mapping for the unpacked parameter of that function: + // + // unpacked[0] -> vget_low_u8(unpacked0) + // unpacked[1] -> vget_high_u8(unpacked0) + // unpacked[2] -> vget_low_u8(unpacked1) + // unpacked[3] -> vget_high_u8(unpacked1) + // + // Before each code section, there is a comment indicating the + // code in pack_4_uint6_values_v2 that is being vectorized. + // + // Input is 32 bytes. + // Output is 6*32= 192 bits = 24 bytes. + uint8x8_t r; + + // packed[0] = unpacked[0] + // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6) + r = vget_low_u8(unpacked0); + r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'0011u)), 6)); + vst1_u8(packed, r); + + // packed[1] = unpacked[1] + // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4) + r = vget_high_u8(unpacked0); + r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'1100u)), 4)); + vst1_u8(packed + 8, r); + + // packed[2] = unpacked[2] + // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2) + r = vget_low_u8(unpacked1); + r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b11'0000u)), 2)); + vst1_u8(packed + 16, r); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + const uint8_t* packed) { + // Unpacks data packed by vec_pack_32_uint6_values_v2. + // + // This function vectorizes unpack_4_uint6_values_v2. + // To understand it, please see unpack_4_uint6_values_v2 first. + // Before each code section, there is a comment indicating the + // code in unpack_4_uint6_values_v2 that is being vectorized. + // + // Input is 24 bytes. + // Output is 32 bytes. + uint8x8_t packed0 = vld1_u8(packed); + uint8x8_t packed1 = vld1_u8(packed + 8); + uint8x8_t packed2 = vld1_u8(packed + 16); + + // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) | + // ((packed[1] & 0b1100'0000u) >> 4) | + // ((packed[2] & 0b1100'0000u) >> 2); + const uint8x8_t high = vdup_n_u8(0b1100'0000u); + uint8x8_t unpacked3; + unpacked3 = vorr_u8(vshr_n_u8(vand_u8(packed0, high), 6), + vshr_n_u8(vand_u8(packed1, high), 4)); + unpacked3 = vorr_u8(unpacked3, + vshr_n_u8(vand_u8(packed2, high), 2)); + + // unpacked[i] = packed[i] & 0b11'1111u; + const uint8x8_t mask = vdup_n_u8(0b11'1111u); + unpacked0 = vcombine_u8(vand_u8(packed0, mask), vand_u8(packed1, mask)); + unpacked1 = vcombine_u8(vand_u8(packed2, mask), unpacked3); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + // This function is a vectorized version of pack_4_uint6_values_v2. + // To understand the following code, please see pack_4_uint6_values_v2 first. + // Before each code section, there is a comment indicating the + // code in pack_4_uint6_values_v2 that is being vectorized. + // + // Input is 48 bytes. + // Output is 64 bytes. + uint8x16_t r; + + // packed[0] = unpacked[0] + // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6) + r = unpacked0; + r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'0011u)), 6)); + vst1q_u8(packed, r); + + // packed[1] = unpacked[1] + // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4) + r = unpacked1; + r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'1100u)), 4)); + vst1q_u8(packed + 16, r); + + // packed[2] = unpacked[2] + // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2) + r = unpacked2; + r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b11'0000u)), 2)); + vst1q_u8(packed + 32, r); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + // Unpacks data packed by vec_pack_64_uint6_values_v2. + // + // This function vectorizes unpack_4_uint6_values_v2. + // To understand it, please see unpack_4_uint6_values_v2 first. + // Before each code section, there is a comment indicating the + // code in unpack_4_uint6_values that is being vectorized + + // Input is 48 bytes. + // Output is 64 bytes. + unpacked0 = vld1q_u8(packed); + unpacked1 = vld1q_u8(packed + 16); + unpacked2 = vld1q_u8(packed + 32); + + // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) | + // ((packed[1] & 0b1100'0000u) >> 4) | + // ((packed[2] & 0b1100'0000u) >> 2); + const uint8x16_t high = vdupq_n_u8(0b1100'0000u); + unpacked3 = vorrq_u8(vshrq_n_u8(vandq_u8(unpacked0, high), 6), + vshrq_n_u8(vandq_u8(unpacked1, high), 4)); + unpacked3 = vorrq_u8(unpacked3, + vshrq_n_u8(vandq_u8(unpacked2, high), 2)); + + // unpacked[i] = packed[i] & 0b11'1111u; + const uint8x16_t mask = vdupq_n_u8(0b11'1111u); + unpacked0 = vandq_u8(unpacked0, mask); + unpacked1 = vandq_u8(unpacked1, mask); + unpacked2 = vandq_u8(unpacked2, mask); +} + } // namespace internal } // namespace bitpacking } // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index ef51fd7d43..434c45a379 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -504,6 +504,23 @@ TEST(test_bitpacking_4_uint6_values, PackUnpackAreSame) { } } +TEST(test_bitpacking_4_uint6_values_v2, PackUnpackAreSame) { + int unpacked_bytes = 4; + int packed_bytes = 3; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_4_uint6_values_v2( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_4_uint6_values_v2( + unpacked.data(), packed.data()); + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + + TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) { int unpacked_bytes = 32; int packed_bytes = 24; @@ -529,6 +546,31 @@ TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) { } } +TEST(test_bitpacking_32_uint6_values_v2, PackUnpackAreSame) { + int unpacked_bytes = 32; + int packed_bytes = 24; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + + input0 = vld1q_u8(input.data()); + input1 = vld1q_u8(input.data() + 16); + torchao::bitpacking::internal::vec_pack_32_uint6_values_v2( + packed.data(), input0, input1); + torchao::bitpacking::internal::vec_unpack_32_uint6_values_v2( + unpacked0, unpacked1, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + } +} + TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) { int unpacked_bytes = 64; int packed_bytes = 48; @@ -560,6 +602,37 @@ TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) { } } +TEST(test_bitpacking_64_uint6_values_v2, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 48; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint6_values_v2( + packed.data(), input0, input1, input2, input3); + torchao::bitpacking::internal::vec_unpack_64_uint6_values_v2( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + // Universal bitpacking tests template void test_bitpacking_32_lowbit_values() {