[experimental][kleidi] Reduce template types for tests
digantdesai committed Oct 8, 2024
1 parent 4271183 commit 7a766df
Showing 2 changed files with 24 additions and 61 deletions.
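
This commit hardcodes weight_nbit = 4 and has_weight_zeros = false inside the Kleidi test wrappers, so the wrappers and their TEST instantiations only carry the two flags that still vary (has_bias and has_clamp), and it adds clamp-enabled cases for both kernel variants. A minimal sketch of the pattern follows; the shared helper name (run_groupwise_lowbit_test) and its exact signature are assumptions for illustration, since the callee is not visible in the diff.

#include <cstdio>

// Stand-in for the shared test helper that the wrappers in test_linear.cpp
// forward to. The real name and full argument list are not shown in the
// diff, so this signature is assumed for illustration only.
void run_groupwise_lowbit_test(
    int m, int k, int n, int group_size,
    int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp,
    bool weight_scale_bf16_round_trip) {
  std::printf("m=%d k=%d n=%d gs=%d nbit=%d zeros=%d bias=%d clamp=%d bf16=%d\n",
              m, k, n, group_size, weight_nbit, has_weight_zeros, has_bias,
              has_clamp, weight_scale_bf16_round_trip);
}

// Before this commit: every caller spelled out all four template parameters.
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void test_wrapper_before(int m, int k, int n, int group_size) {
  run_groupwise_lowbit_test(
      m, k, n, group_size,
      weight_nbit, has_weight_zeros, has_bias, has_clamp,
      /*weight_scale_bf16_round_trip=*/true);
}

// After this commit: the Kleidi int4 kernels under test always use 4-bit
// weights without zero points, so those values are fixed in the wrapper and
// only has_bias / has_clamp remain as template parameters.
template <bool has_bias, bool has_clamp>
void test_wrapper_after(int m, int k, int n, int group_size) {
  run_groupwise_lowbit_test(
      m, k, n, group_size,
      /*weight_nbit=*/4, /*has_weight_zeros=*/false, has_bias, has_clamp,
      /*weight_scale_bf16_round_trip=*/true);
}

int main() {
  // The same test case, spelled both ways.
  test_wrapper_before<4, false, false, false>(/*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
  test_wrapper_after<false /*has_bias*/, false /*has_clamp*/>(/*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
  return 0;
}
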
46 changes: 24 additions & 22 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -357,7 +357,7 @@ TEST(
// #ifdef TORCHAO_ENABLE_KLEIDI
// TODO: Wire up the compile definition for TORCHAO_ENABLE_KLEIDI

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
template <bool has_bias, bool has_clamp>
void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
int m,
int k,
@@ -369,8 +369,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
k,
n,
group_size,
weight_nbit,
has_weight_zeros,
/*weight_nbit=*/4,
/*has_weight_zeros*/false,
has_bias,
has_clamp,
/*weight_scale_bf16_round_trip=*/true);
@@ -421,8 +421,6 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
k_eq_gs_32) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
@@ -432,8 +430,6 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
large_k_n_gs32) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
@@ -443,8 +439,6 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
even_n_gs32) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
@@ -454,14 +448,21 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
k_eq_gs128) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
}

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
clamp_k_eq_gs128) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
}

template <bool has_bias, bool has_clamp>
void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
int m,
int k,
@@ -473,8 +474,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
k,
n,
group_size,
weight_nbit,
has_weight_zeros,
/*weight_nbit=*/4,
/*has_weight_zeros=*/false,
has_bias,
has_clamp,
/*weight_scale_bf16_round_trip=*/true);
@@ -525,8 +526,6 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
k_eq_gs_32) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
@@ -536,8 +535,6 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
large_k_n_gs32) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
@@ -547,8 +544,6 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
even_n_gs32) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
@@ -558,11 +553,18 @@ TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
k_eq_gs128) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
}

TEST(
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
clamp_k_eq_gs128) {
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
}
// #endif // defined(TORCHAO_ENABLE_KLEIDI)
#endif // defined(__aarch64__) || defined(__ARM_NEON)
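
With the reduced template, a new case only needs the two remaining flags. As a purely hypothetical illustration (not part of this commit, and the diff does not establish whether the Kleidi path supports a non-zero bias), a bias-enabled variant would read:

TEST(
    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
    bias_k_eq_gs128) {  // hypothetical example, not in this commit
  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
      true /*has_bias*/,
      false /*has_clamp*/>(
      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
}
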
39 changes: 0 additions & 39 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -235,7 +235,6 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
zero,
qmin,
qmax);
// std::fill(weight_qvals.begin(), weight_qvals.end(), -7);
}

std::vector<float> bias(m, 0.0);
@@ -277,44 +276,6 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
}
}

#if 0 // Alternate reference implementation for debugging.
auto num_groups = k / weight_group_size;
for (int m_idx = 0; m_idx < m; m_idx++) {
for (int n_idx = 0; n_idx < n; n_idx++) {
int32_t result_idx = m_idx * n + n_idx;
float weights_fsum = 0.0;
for (int g_idx = 0; g_idx < num_groups; g_idx++) {
int32_t weights_qsum = 0;
int32_t acc_i32 = 0;
for (int k_idx = 0; k_idx < weight_group_size; k_idx++) {
const int32_t activation_idx = m_idx * k + g_idx * weight_group_size + k_idx;
const int32_t weight_idx = n_idx * k + g_idx * weight_group_size + k_idx;

const int32_t weight_qval = weight_qvals[weight_idx];
const int32_t activation_qval = activation_qvals[activation_idx];

weights_qsum += weight_qval;
acc_i32 += weight_qval * activation_qval;
}
// For each group, we have a weight scale
const int32_t weight_scale_idx = n_idx * num_groups + g_idx;
const float weight_scale = weight_scales[weight_scale_idx]; // already rounded trip to bf16
expected_output[result_idx] += (float) acc_i32 * weight_scales[weight_scale_idx];
weights_fsum += weights_qsum * weight_scale;
}
// For each output channel, we have an activation scale
const int32_t activation_zero_point = activation_zeros[m_idx];
const float activation_scale = activation_scales[m_idx];
expected_output[result_idx] -= activation_zero_point * weights_fsum;
expected_output[result_idx] *= activation_scale;
expected_output[result_idx] += bias[m_idx];
if (has_clamp) {
expected_output[result_idx] = std::min(std::max(expected_output[result_idx], clamp_min), clamp_max);
}
}
}
#endif

// Return test case
return channelwise_8bit_activation_groupwise_lowbit_weight_test_case(
m,
