
[experimental][kleidi] linter
digantdesai committed Oct 8, 2024
1 parent 8bad3e6 commit 61694eb
Showing 4 changed files with 34 additions and 41 deletions.
@@ -1,4 +1,3 @@
-// namespace example
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
@@ -42,8 +41,9 @@ const Ukernel get_ukernel() {
}

int activation_data_size(int m, int k, int group_size) {
-  (void) group_size; // unused
-  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
}

void prepare_activation_data(
@@ -52,17 +52,14 @@ void prepare_activation_data(
    int k,
    int group_size,
    const float* activations) {
-  (void) group_size; // unused
+  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(),
-      activation_data,
-      m,
-      k,
-      activations);
+      get_ukernel(), activation_data, m, k, activations);
}

int weight_data_size(int n, int k, int group_size) {
-  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
@@ -96,24 +93,24 @@ void kernel(
const float* bias,
float clamp_min,
float clamp_max) {
-  (void) bias; // unused - needs API fixing
-  assert(output_m_stride == n);
-  if (clamp_min == 0 && clamp_max == 0) {
-    clamp_min = std::numeric_limits<float>::lowest();
-    clamp_max = std::numeric_limits<float>::max();
-  }
+  (void)bias; // unused - needs API fixing
+  assert(output_m_stride == n);
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }

-  auto ukernel = get_ukernel();
-  ukernel.run_matmul(
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
-      /*dst_stride_row=*/ n * sizeof(float),
-      /*dst_stride_col=*/ sizeof(float),
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}
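Taken together, these wrappers follow a size/prepare/run pattern. A minimal sketch of the call sequence, based on the call sites in test_linear.cpp further down; only prepare_activation_data's argument list is fully visible in this diff, so the weight-packing and kernel calls are left as comments:

#include <vector>

void linear_sketch(int m, int n, int k, int group_size, const float* activations) {
  // 1. Size and pack the activations (shown in full above).
  std::vector<char> activation_data(activation_data_size(m, k, group_size));
  prepare_activation_data(
      (void*)activation_data.data(), m, k, group_size, activations);

  // 2. Size the packed weights; prepare_weight_data's full argument list
  // is collapsed in this diff, so it is elided here.
  std::vector<char> weight_data(weight_data_size(n, k, group_size));
  // prepare_weight_data((void*)weight_data.data(), ...);

  // 3. kernel(...) then runs the matmul. Note from the body above: it
  // asserts output_m_stride == n, ignores bias for now, and treats
  // clamp_min == clamp_max == 0 as "no clamp" by widening to the full
  // float range.
}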
@@ -1,4 +1,3 @@
-// namespace example
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
@@ -1,4 +1,3 @@
-// namespace example
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
@@ -7,10 +6,10 @@

#pragma once

-#include <cstdint>
+#include <cassert>
#include <cstddef>
+#include <cstdint>
#include <cstring>
-#include <cassert>
#include <limits>
#include <vector>

@@ -25,7 +24,7 @@ namespace torchao::kernels::cpu::aarch64::kleidi {
// TODO: find a better place for these?

size_t roundup(size_t a, size_t b) {
-  return ((a + b - 1) / b) * b;
+  return ((a + b - 1) / b) * b;
}

uint16_t get_bf16_from_float(float f) {
@@ -111,13 +110,15 @@ void prepare_weight_data(
  uint8_t wzp = 8;
  for (size_t i = 0; i < n * k; i += 2) {
    const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
-    const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
+    const uint8_t high = static_cast<uint8_t>(weight_qvals[i + 1] + wzp);
    packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
  }

  // Parameters for packing
  rhs_packing::qparams_t qparams{
-      .lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
+      .lhs_zero_point = 1,
+      .rhs_zero_point = wzp,
+      .scale_dt = kai_datatype::kai_dt_bf16};

auto rhs_pack = get_rhs_packing();

@@ -133,7 +134,7 @@
      /*rhs_stride=*/roundup(k, 2) / 2,
      /*bias=*/nullptr, // TODO fix APIs to move bias here
      /*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
-      /*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
+      /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size),
      /*rhs_packed=*/weight_data,
      /*extra_bytes=*/0,
      /*qparams=*/&qparams);
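The packing loop above biases each signed int4 value by the zero point (wzp = 8) into an unsigned nibble and stores two values per byte, and roundup rounds up to the next multiple (roundup(10, 4) == 12), which sizes the bf16 scale stride at one entry per group along k. A self-contained illustration of both pieces of arithmetic, with made-up values:

#include <cassert>
#include <cstddef>
#include <cstdint>

size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

int main() {
  // Two int4 values in [-8, 7], biased by wzp = 8 into unsigned nibbles:
  // even index -> low nibble, odd index -> high nibble.
  const uint8_t wzp = 8;
  const int8_t q0 = -3, q1 = 5;
  const uint8_t low = static_cast<uint8_t>(q0 + wzp);   // 5
  const uint8_t high = static_cast<uint8_t>(q1 + wzp);  // 13
  const uint8_t packed = (high << 4) | (low & 0xF);
  assert(packed == 0xD5);

  // One bf16 scale per group of group_size along k, with k rounded up
  // to a whole number of groups: k = 100 -> 4 groups of 32.
  const size_t k = 100, group_size = 32;
  const size_t scale_stride =
      sizeof(uint16_t) * (roundup(k, group_size) / group_size);
  assert(scale_stride == 8);
  return 0;
}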
22 changes: 9 additions & 13 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -6,8 +6,8 @@

#if defined(__aarch64__) || defined(__ARM_NEON)

-#include <vector>
#include <arm_neon.h>
+#include <vector>

#include <gtest/gtest.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
@@ -375,10 +375,10 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
has_clamp,
/*weight_scale_bf16_round_trip=*/true);

-  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;

-  std::vector<char> activation_data(
-      activation_data_size(m, k, group_size));
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));

prepare_activation_data(
(void*)activation_data.data(),
@@ -387,8 +387,7 @@
group_size,
test_case.activations.data());

-  std::vector<char> weight_data(
-      weight_data_size(n, k, group_size));
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));

prepare_weight_data(
(void*)weight_data.data(),
@@ -462,8 +461,6 @@ TEST(
/*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
}

-
-
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
int m,
@@ -482,10 +479,10 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
has_clamp,
/*weight_scale_bf16_round_trip=*/true);

-  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;

-  std::vector<char> activation_data(
-      activation_data_size(m, k, group_size));
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));

prepare_activation_data(
(void*)activation_data.data(),
@@ -494,8 +491,7 @@
group_size,
test_case.activations.data());

-  std::vector<char> weight_data(
-      weight_data_size(n, k, group_size));
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));

prepare_weight_data(
(void*)weight_data.data(),
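The tests request /*weight_scale_bf16_round_trip=*/true because prepare_weight_data stores scales as bf16 (scale_dt = kai_dt_bf16 above), so reference scales must survive a float -> bf16 -> float round trip. The body of get_bf16_from_float is collapsed in this view; a common truncation-based implementation, sketched here as an assumption rather than the actual helper, looks like:

#include <cstdint>
#include <cstring>

// Assumed implementation: keep the sign, exponent, and top 7 mantissa
// bits of the float. Hypothetical names; the real helper's body is not
// shown in this diff.
uint16_t bf16_from_float(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

float float_from_bf16(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}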
