diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index 926d475239..fd39de608b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -6,9 +6,9 @@ #include #include -#include #include +#include #include #include #include @@ -16,6 +16,134 @@ namespace { +// Benchmark utility to compare variants of uint2 packing +void pack_uint2_values( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 2; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x8_t unpacked0_8x8; + uint8x8_t unpacked1_8x8; + uint8x8_t unpacked2_8x8; + uint8x8_t unpacked3_8x8; + + uint8x16_t unpacked0_8x16; + uint8x16_t unpacked1_8x16; + uint8x16_t unpacked2_8x16; + uint8x16_t unpacked3_8x16; + + switch (variant) { + case 4: + for (int i = 0; i < unpacked_size; i += 4) { + torchao::bitpacking::internal::pack_4_uint2_values( + packed + ((i * nbit) / bitsPerByte), unpacked + i); + } + break; + case 32: + for (int i = 0; i < unpacked_size; i += 32) { + torchao::bitpacking::internal::vec_load_32_uint8_values( + unpacked0_8x8, + unpacked1_8x8, + unpacked2_8x8, + unpacked3_8x8, + unpacked + i); + torchao::bitpacking::internal::vec_pack_32_uint2_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0_8x8, + unpacked1_8x8, + unpacked2_8x8, + unpacked3_8x8); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0_8x16, + unpacked1_8x16, + unpacked2_8x16, + unpacked3_8x16, + unpacked + i); + torchao::bitpacking::internal::vec_pack_64_uint2_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0_8x16, + unpacked1_8x16, + 
unpacked2_8x16, + unpacked3_8x16); + } + break; + } +} + +// Benchmark utility to compare variants of uint2 packing +void unpack_uint2_values( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 2; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x8_t unpacked0_8x8; + uint8x8_t unpacked1_8x8; + uint8x8_t unpacked2_8x8; + uint8x8_t unpacked3_8x8; + + uint8x16_t unpacked0_8x16; + uint8x16_t unpacked1_8x16; + uint8x16_t unpacked2_8x16; + uint8x16_t unpacked3_8x16; + + switch (variant) { + case 4: + for (int i = 0; i < unpacked_size; i += 4) { + torchao::bitpacking::internal::unpack_4_uint2_values( + unpacked + i, packed + ((i * nbit) / bitsPerByte)); + } + break; + case 32: + for (int i = 0; i < unpacked_size; i += 32) { + torchao::bitpacking::internal::vec_unpack_32_uint2_values( + unpacked0_8x8, + unpacked1_8x8, + unpacked2_8x8, + unpacked3_8x8, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_32_uint8_values( + unpacked + i, + unpacked0_8x8, + unpacked1_8x8, + unpacked2_8x8, + unpacked3_8x8); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + unpacked0_8x16, + unpacked1_8x16, + unpacked2_8x16, + unpacked3_8x16, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, + unpacked0_8x16, + unpacked1_8x16, + unpacked2_8x16, + unpacked3_8x16); + } + break; + } +} + // Benchmark utility to compare variants of uint3 packing void pack_uint3_values( uint8_t* packed, @@ -220,6 +348,44 @@ void unpack_uint4_values( } // namespace +static void benchmark_pack_uint2_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 2; + + assert(unpacked_size % 8 == 0); + int packed_size = 
(unpacked_size / 8) * nbit;
+
+  // Pack benchmark: random 32-byte-per-128 unpacked input, packed output.
+  // NOTE: sizes were swapped in the original patch (packed got unpacked_size,
+  // unpacked got packed_size), making pack_uint2_values read past the end of
+  // `unpacked`. Allocate each buffer with its own size.
+  auto packed = std::vector<uint8_t>(packed_size, 0);
+  auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, 8);
+
+  for (auto _ : state) {
+    pack_uint2_values(
+        packed.data(), unpacked.data(), packed_size, unpacked_size, variant);
+  }
+}
+
+static void benchmark_unpack_uint2_values(benchmark::State& state) {
+  int unpacked_size = state.range(0);
+  int variant = state.range(1);
+  int nbit = 2;
+
+  assert(unpacked_size % 8 == 0);
+  int packed_size = (unpacked_size / 8) * nbit;
+
+  auto packed = torchao::get_random_lowbit_vector(packed_size, 8);
+  auto unpacked = std::vector<uint8_t>(unpacked_size, 0);
+
+  for (auto _ : state) {
+    unpack_uint2_values(
+        unpacked.data(),
+        packed.data(),
+        unpacked.size(),
+        packed.size(),
+        variant);
+  }
+}
+
 static void benchmark_pack_uint3_values(benchmark::State& state) {
   int unpacked_size = state.range(0);
   int variant = state.range(1);
@@ -296,6 +462,8 @@ static void benchmark_unpack_uint4_values(benchmark::State& state) {
   }
 }
 
+BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
+BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
 BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
 BENCHMARK(benchmark_unpack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
 BENCHMARK(benchmark_pack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}});
diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp
index 8e3ec0516f..12f2b2bdb7 100644
--- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp
@@ -228,14 +228,20 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
         false>) \
     ->ArgsProduct(BENCHMARK_PARAMS)
 
+BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
+    2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 3); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 4); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 3); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 4); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( + 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( 3); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index fce5abba42..3503d7fef9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -7,6 +7,7 @@ #pragma once #include #include +#include #include #include #include @@ -15,6 +16,30 @@ namespace torchao { namespace bitpacking { namespace internal { +TORCHAO_ALWAYS_INLINE inline void vec_store_32_uint8_values( + uint8_t* dest, + const uint8x8_t& vec0, + const uint8x8_t& vec1, + const uint8x8_t& vec2, + const uint8x8_t& vec3) { + vst1_u8(dest, vec0); + vst1_u8(dest + 8, vec1); + vst1_u8(dest + 16, vec2); + vst1_u8(dest + 24, vec3); +} + +TORCHAO_ALWAYS_INLINE inline void vec_load_32_uint8_values( + uint8x8_t& vec0, + uint8x8_t& vec1, + uint8x8_t& vec2, + uint8x8_t& vec3, + const uint8_t* src) { + vec0 = vld1_u8(src); + vec1 = vld1_u8(src + 8); + vec2 = vld1_u8(src + 16); + vec3 = vld1_u8(src + 24); +} + TORCHAO_ALWAYS_INLINE inline void vec_store_64_uint8_values( uint8_t* dest, const uint8x16_t& vec0, @@ -49,7 +74,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( static_assert(nbit >= 2); // Currently supported values - static_assert(nbit >= 3); + 
static_assert(nbit >= 2);
   static_assert(nbit <= 4);
 
   // Shift unpacked values to nonnegative range
@@ -58,6 +83,14 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
   uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
 
   switch (nbit) {
+    case 2:
+      torchao::bitpacking::internal::vec_pack_32_uint2_values(
+          packed,
+          vget_low_u8(shifted0),
+          vget_high_u8(shifted0),
+          vget_low_u8(shifted1),
+          vget_high_u8(shifted1));
+      // break was missing in the original patch: fallthrough into case 3
+      // overwrote the uint2-packed bytes with uint3-packed data.
+      break;
     case 3:
       uint8_t buffer[32];
       vst1q_u8(buffer, shifted0);
@@ -89,13 +122,23 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
   static_assert(nbit >= 2);
 
   // Currently supported values
-  static_assert(nbit >= 3);
+  static_assert(nbit >= 2);
   static_assert(nbit <= 4);
 
   uint8x16_t shifted0;
   uint8x16_t shifted1;
 
   switch (nbit) {
+    case 2:
+      uint8x8_t shifted0_low;
+      uint8x8_t shifted0_high;
+      uint8x8_t shifted1_low;
+      uint8x8_t shifted1_high;
+      torchao::bitpacking::internal::vec_unpack_32_uint2_values(
+          shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed);
+      shifted0 = vcombine_u8(shifted0_low, shifted0_high);
+      shifted1 = vcombine_u8(shifted1_low, shifted1_high);
+      // break was missing in the original patch: fallthrough into case 3
+      // re-unpacked `packed` as uint3 and clobbered shifted0/shifted1.
+      break;
     case 3:
       uint8_t buffer[32];
       torchao::bitpacking::internal::unpack_8_uint3_values(buffer, packed);
@@ -133,7 +176,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
   static_assert(nbit >= 2);
 
   // Currently supported values
-  static_assert(nbit >= 3);
+  static_assert(nbit >= 2);
   static_assert(nbit <= 4);
 
   // Shift unpacked values to nonnegative range
@@ -144,6 +187,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
   uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
 
   switch (nbit) {
+    case 2:
+      torchao::bitpacking::internal::vec_pack_64_uint2_values(
+          packed, shifted0, shifted1, shifted2, shifted3);
+      break;
     case 3:
       torchao::bitpacking::internal::vec_pack_64_uint3_values(
           packed, shifted0, shifted1, shifted2, shifted3);
@@ -170,7 +217,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
static_assert(nbit >= 2); // Currently supported values - static_assert(nbit >= 3); + static_assert(nbit >= 2); static_assert(nbit <= 4); uint8x16_t shifted0; @@ -179,6 +224,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( uint8x16_t shifted3; switch (nbit) { + case 2: + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + shifted0, shifted1, shifted2, shifted3, packed); + break; case 3: torchao::bitpacking::internal::vec_unpack_64_uint3_values( shifted0, shifted1, shifted2, shifted3, packed); @@ -216,7 +265,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( static_assert(nbit >= 2); // Currently supported values - static_assert(nbit >= 3); + static_assert(nbit >= 2); static_assert(nbit <= 4); // Shift unpacked values to nonnegative range @@ -231,6 +280,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift)); switch (nbit) { + case 2: + torchao::bitpacking::internal::vec_pack_64_uint2_values( + packed, shifted0, shifted1, shifted2, shifted3); + torchao::bitpacking::internal::vec_pack_64_uint2_values( + packed + 16, shifted4, shifted5, shifted6, shifted7); + break; case 3: torchao::bitpacking::internal::vec_pack_128_uint3_values( packed, @@ -273,7 +328,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( static_assert(nbit >= 2); // Currently supported values - static_assert(nbit >= 3); + static_assert(nbit >= 2); static_assert(nbit <= 4); uint8x16_t shifted0; @@ -286,6 +341,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( uint8x16_t shifted7; switch (nbit) { + case 2: + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + shifted0, shifted1, shifted2, shifted3, packed); + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + shifted4, shifted5, shifted6, shifted7, packed + 16); + break; case 3: torchao::bitpacking::internal::vec_unpack_128_uint3_values( shifted0, diff --git 
a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h
new file mode 100644
index 0000000000..985dfd9a72
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h
@@ -0,0 +1,132 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <arm_neon.h>
+#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
+
+// This file contains bitpacking and unpacking methods for uint2.
+// These are not intended to be used outside of bitpacking directory.
+// See bitpack.h for the interface.
+
+namespace torchao {
+namespace bitpacking {
+namespace internal {
+
+TORCHAO_ALWAYS_INLINE inline void pack_4_uint2_values(
+    uint8_t* packed,
+    const uint8_t* unpacked) {
+  // Input is 4 bytes
+  // Output is 1 byte
+
+  packed[0] = (unpacked[0] << 6) | (unpacked[1] << 4) | (unpacked[2] << 2) |
+      (unpacked[3]);
+}
+
+TORCHAO_ALWAYS_INLINE inline void unpack_4_uint2_values(
+    uint8_t* unpacked,
+    const uint8_t* packed) {
+  // Input is 1 byte
+  // Output is 4 bytes
+  unpacked[0] = (packed[0] & 192) >> 6;
+  unpacked[1] = (packed[0] & 48) >> 4;
+  unpacked[2] = (packed[0] & 12) >> 2;
+  unpacked[3] = (packed[0] & 3);
+}
+
+TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint2_values(
+    uint8_t* packed,
+    const uint8x8_t& unpacked0,
+    const uint8x8_t& unpacked1,
+    const uint8x8_t& unpacked2,
+    const uint8x8_t& unpacked3) {
+  // Input is 32 bytes
+  // Output is 8 bytes
+
+  // Vectorize the following:
+  // packed[0] = (unpacked[0] << 6) | (unpacked[1] << 4) | (unpacked[2] << 2) |
+  // (unpacked[3]);
+
+  uint8x8_t vec_packed;
+  vec_packed = vshl_n_u8(unpacked0, 6);
+  vec_packed = vorr_u8(vec_packed, vshl_n_u8(unpacked1, 4));
+  vec_packed = vorr_u8(vec_packed, vshl_n_u8(unpacked2, 2));
+  vec_packed = vorr_u8(vec_packed, unpacked3);
+  vst1_u8(packed, vec_packed);
+}
+
+TORCHAO_ALWAYS_INLINE
inline void vec_unpack_32_uint2_values( + uint8x8_t& unpacked0, + uint8x8_t& unpacked1, + uint8x8_t& unpacked2, + uint8x8_t& unpacked3, + const uint8_t* packed) { + // Input is 8 bytes + // Output is 32 bytes + + // Vectorize the following: + // unpacked[0] = (packed[0] & 192) >> 6; + // unpacked[1] = (packed[0] & 48) >> 4; + // unpacked[2] = (packed[0] & 12) >> 2; + // unpacked[3] = (packed[0] & 3); + + uint8x8_t vec_packed; + + vec_packed = vld1_u8(packed); + unpacked0 = vshr_n_u8(vand_u8(vec_packed, vdup_n_u8(192)), 6); + unpacked1 = vshr_n_u8(vand_u8(vec_packed, vdup_n_u8(48)), 4); + unpacked2 = vshr_n_u8(vand_u8(vec_packed, vdup_n_u8(12)), 2); + unpacked3 = vand_u8(vec_packed, vdup_n_u8(3)); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint2_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + // Input is 64 bytes + // Output is 16 bytes + + // Vectorize the following: + // packed[0] = (unpacked[0] << 6) | (unpacked[1] << 4) | (unpacked[2] << 2) | + // (unpacked[3]); + + uint8x16_t vec_packed; + vec_packed = vshlq_n_u8(unpacked0, 6); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked1, 4)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked2, 2)); + vec_packed = vorrq_u8(vec_packed, unpacked3); + vst1q_u8(packed, vec_packed); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint2_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + // Input is 16 bytes + // Output is 64 bytes + + // Vectorize the following: + // unpacked[0] = (packed[0] & 192) >> 6; + // unpacked[1] = (packed[0] & 48) >> 4; + // unpacked[2] = (packed[0] & 12) >> 2; + // unpacked[3] = (packed[0] & 3); + + uint8x16_t vec_packed; + + vec_packed = vld1q_u8(packed); + unpacked0 = vshrq_n_u8(vandq_u8(vec_packed, vdupq_n_u8(192)), 6); + unpacked1 = vshrq_n_u8(vandq_u8(vec_packed, vdupq_n_u8(48)), 
4); + unpacked2 = vshrq_n_u8(vandq_u8(vec_packed, vdupq_n_u8(12)), 2); + unpacked3 = vandq_u8(vec_packed, vdupq_n_u8(3)); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index 28a46f8e06..baf6044c25 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -7,11 +7,92 @@ #include #include #include +#include #include #include #include #include +TEST(test_bitpacking_4_uint2_values, PackUnpackAreSame) { + int unpacked_bytes = 4; + int packed_bytes = 1; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 2); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_4_uint2_values( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_4_uint2_values( + unpacked.data(), packed.data()); + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_32_uint2_values, PackUnpackAreSame) { + int unpacked_bytes = 32; + int packed_bytes = 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 2); + std::vector packed(packed_bytes, 0); + + uint8x8_t input0; + uint8x8_t input1; + uint8x8_t input2; + uint8x8_t input3; + + uint8x8_t unpacked0; + uint8x8_t unpacked1; + uint8x8_t unpacked2; + uint8x8_t unpacked3; + + torchao::bitpacking::internal::vec_load_32_uint8_values( + input0, input1, input2, input3, input.data()); + + torchao::bitpacking::internal::vec_pack_32_uint2_values( + packed.data(), input0, input1, input2, input3); + torchao::bitpacking::internal::vec_unpack_32_uint2_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], 
unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + +TEST(test_bitpacking_64_uint2_values, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 16; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 2); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + + torchao::bitpacking::internal::vec_pack_64_uint2_values( + packed.data(), input0, input1, input2, input3); + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + TEST(test_bitpacking_8_uint3_values, PackUnpackAreSame) { int unpacked_bytes = 8; int packed_bytes = 3; @@ -189,8 +270,7 @@ template void test_bitpacking_32_lowbit_values() { int unpacked_bytes = 32; int packed_bytes = unpacked_bytes * nbit / 8; - auto input_shifted = - torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); std::vector input(unpacked_bytes, 0); int8_t low = -(1 << (nbit - 1)); int8_t high = (1 << (nbit - 1)); @@ -222,8 +302,7 @@ template void test_bitpacking_64_lowbit_values() { int unpacked_bytes = 64; int packed_bytes = unpacked_bytes * nbit / 8; - auto input_shifted = - torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); std::vector input(unpacked_bytes, 0); int8_t low = -(1 << (nbit - 1)); int8_t high = (1 << (nbit - 1)); @@ -263,8 +342,7 @@ template void 
test_bitpacking_128_lowbit_values() { int unpacked_bytes = 128; int packed_bytes = unpacked_bytes * nbit / 8; - auto input_shifted = - torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); std::vector input(unpacked_bytes, 0); int8_t low = -(1 << (nbit - 1)); int8_t high = (1 << (nbit - 1)); @@ -347,11 +425,14 @@ void test_bitpacking_128_lowbit_values() { test_bitpacking_128_lowbit_values(); \ } +TEST_BITPACKING_32_LOWBIT_VALUES(2); TEST_BITPACKING_32_LOWBIT_VALUES(3); TEST_BITPACKING_32_LOWBIT_VALUES(4); +TEST_BITPACKING_64_LOWBIT_VALUES(2); TEST_BITPACKING_64_LOWBIT_VALUES(3); TEST_BITPACKING_64_LOWBIT_VALUES(4); +TEST_BITPACKING_128_LOWBIT_VALUES(2); TEST_BITPACKING_128_LOWBIT_VALUES(3); TEST_BITPACKING_128_LOWBIT_VALUES(4);