Skip to content

Commit

Permalink
Use fewer instructions when unpacking uint6s.
Browse files Browse the repository at this point in the history
Differential Revision: D64548639

Pull Request resolved: pytorch#1109
  • Loading branch information
c4lcut3c authored Oct 17, 2024
1 parent 6653b45 commit 3475aed
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,20 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
uint8x8_t packed1 = vld1_u8(packed + 8);
uint8x8_t packed2 = vld1_u8(packed + 16);

// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
// ((packed[1] & 0b1100'0000u) >> 4) |
// ((packed[2] & 0b1100'0000u) >> 2);
const uint8x8_t high = vdup_n_u8(0b1100'0000u);
uint8x8_t unpacked3;
unpacked3 = vorr_u8(
vshr_n_u8(vand_u8(packed0, high), 6),
vshr_n_u8(vand_u8(packed1, high), 4));
unpacked3 = vorr_u8(unpacked3, vshr_n_u8(vand_u8(packed2, high), 2));
// We want to extract bits 123456 and place them in unpacked3.
// Packed structure is:
//
// packed0: 56 | abcdef
// packed1: 34 | ghijkl
// packed2: 12 | mnopqr
//
// unpacked3 = 1234 ghij
unpacked3 = vsri_n_u8(packed2, packed1, 2);
// unpacked3 = 1234 56ab
unpacked3 = vsri_n_u8(unpacked3, packed0, 4);
// unpacked3 = 0012 3456
unpacked3 = vshr_n_u8(unpacked3, 2);

// unpacked[i] = packed[i] & 0b11'1111u;
const uint8x8_t mask = vdup_n_u8(0b11'1111u);
Expand Down Expand Up @@ -183,14 +188,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
unpacked1 = vld1q_u8(packed + 16);
unpacked2 = vld1q_u8(packed + 32);

// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
// ((packed[1] & 0b1100'0000u) >> 4) |
// ((packed[2] & 0b1100'0000u) >> 2);
const uint8x16_t high = vdupq_n_u8(0b1100'0000u);
unpacked3 = vorrq_u8(
vshrq_n_u8(vandq_u8(unpacked0, high), 6),
vshrq_n_u8(vandq_u8(unpacked1, high), 4));
unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(unpacked2, high), 2));
// We want to extract bits 123456 and place them in unpacked3.
// Packed structure is:
//
// packed0: 56 | abcdef
// packed1: 34 | ghijkl
// packed2: 12 | mnopqr
//
// unpacked3 = 1234 ghij
unpacked3 = vsriq_n_u8(unpacked2, unpacked1, 2);
// unpacked3 = 1234 56ab
unpacked3 = vsriq_n_u8(unpacked3, unpacked0, 4);
// unpacked3 = 0012 3456
unpacked3 = vshrq_n_u8(unpacked3, 2);

// unpacked[i] = packed[i] & 0b11'1111u;
const uint8x16_t mask = vdupq_n_u8(0b11'1111u);
Expand Down

0 comments on commit 3475aed

Please sign in to comment.