
Kleidi 4b blockwise gemv prototype #997

Merged · 19 commits · Oct 11, 2024
Changes from 1 commit

Commits (19):
036b782
[experimental] simple script UX fixes
digantdesai Oct 2, 2024
c4b9f1e
[experimental][kleidi] Add build support
digantdesai Oct 2, 2024
4a85c4d
[experimental][kleidi] Add uConfig support for qb4w 1x4x32 neon dotprod
digantdesai Oct 2, 2024
49afa4a
[experimental][kleidi] Add a basic test - compiles
digantdesai Oct 2, 2024
569c069
[experimental][kleidi] Pin kleidiai repo
digantdesai Oct 8, 2024
fd1423f
[experimental][kleidi] Clean up pack.h
digantdesai Oct 8, 2024
c323fb1
[experimental][kleidi] Refactor interface header
digantdesai Oct 8, 2024
8aa27c4
[experimental][kleidi] Improve unit-tests
digantdesai Oct 8, 2024
44ca4de
[experimental][kleidi] move common functions to interface
digantdesai Oct 8, 2024
c272739
[experimental][kleidi] Add 1x8x32 neon dotprod kernel
digantdesai Oct 8, 2024
ee62be5
[experimental][kleidi] linter
digantdesai Oct 8, 2024
ee49c6e
[experimental][kleidi] Reduce template types for tests
digantdesai Oct 8, 2024
a905ec3
[experimental][kleidi] Add m>1 tests
digantdesai Oct 10, 2024
7429bea
[experimental][kleidi] rename bf16 weight scale flag
digantdesai Oct 10, 2024
f28e556
[experimental][kleidi] Build kernel tests in debug mode
digantdesai Oct 10, 2024
17f2b43
[experimental][kleidi] Add TODO tasks
digantdesai Oct 10, 2024
3049ded
[experimental][kleidi] Allow weight zeros to be a nullptr
digantdesai Oct 10, 2024
d4bb3ed
[experimental][kleidi] rebase fixes with int to size_t
digantdesai Oct 10, 2024
f6e22fb
[experimental][kleidi] compile-time preprocessor switch for kleidi tests
digantdesai Oct 11, 2024
[experimental][kleidi] move common functions to interface
digantdesai committed Oct 11, 2024
commit 44ca4defb5decbad446fc8ebf4537f552bc742e7
@@ -6,26 +6,17 @@
// LICENSE file in the root directory of this source tree.

#pragma once
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

#include <kai/kai_common.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
Contributor:
What is the C++ standard compliance for using a namespace like this? Just confirm that it is at least C++17.

Contributor (Author):
CMake dictates we can assume C++17.
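For reference, the construct in question is a C++17 nested namespace definition. A minimal sketch (illustrative only, not part of the diff):

// C++17 nested namespace definition: one declaration opens the whole chain.
namespace torchao::kernels::cpu::aarch64::kleidi {
inline int answer() { return 42; }
} // namespace torchao::kernels::cpu::aarch64::kleidi

// Pre-C++17 spelling of the same thing:
namespace torchao { namespace kernels { namespace cpu {
namespace aarch64 { namespace kleidi {
inline int answer_legacy() { return 42; }
}}}}} // namespace torchao::kernels::cpu::aarch64::kleidi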

namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

namespace neon_dotprod_1x4x32 {
ukernel get_ukernel() {
return ukernel{
const Ukernel get_ukernel() {
digantdesai marked this conversation as resolved.
return Ukernel{
.get_m_step =
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_n_step =
@@ -50,158 +41,79 @@ ukernel get_ukernel() {
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
}

size_t roundup(size_t a, size_t b) {
return ((a + b - 1) / b) * b;
}

int activation_data_size(int m, int k, int group_size) {
auto ukernel = get_ukernel();
auto lhs_packing = get_lhs_packing();
return lhs_packing.get_lhs_packed_size(
m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
(void) group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
}

void prepare_activation_data(
void* activation_data,
// Inputs
int m,
int k,
// Ignored if has_weight_zeros = false
int group_size,
const float* activations) {
auto ukernel = get_ukernel();
auto lhs_pack = get_lhs_packing();

lhs_pack.run_lhs_pack(
(void) group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(),
activation_data,
m,
k,
ukernel.get_mr(),
ukernel.get_kr(),
ukernel.get_sr(),
/*m_index_start=*/0,
activations,
/*lhs_stride=*/k * sizeof(float),
activation_data);
activations);
}

int weight_data_size(int n, int k, int group_size) {
auto ukernel = get_ukernel();
auto rhs_pack = get_rhs_packing();
return rhs_pack.get_rhs_packed_size(
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
group_size,
kai_datatype::kai_dt_bf16);
}

inline uint16_t get_bf16_from_float(float f) {
uint16_t bf16;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
memcpy(&bf16, &f, sizeof(uint16_t));
#else
const void* fp = reinterpret_cast<const void*>(
reinterpret_cast<uintptr_t>(&f) + sizeof(float) - sizeof(uint16_t));
memcpy(&bf16, fp, sizeof(uint16_t));
#endif // __BYTE_ORDER__
return bf16;
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
}

// TODO: move most of these functions in the parent namespace and take in
// ukernel as a parameter
void prepare_weight_data(
void* weight_data,
// Inputs
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros) {
// TODO - remove this constraint and pad when possible
assert(n % 2 == 0);

assert(group_size % 32 == 0);
assert(k % group_size == 0);

// Convert scales to bf16
// TODO SIMDify this
size_t n_groups = n * k / group_size;
auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
for (size_t i = 0; i < n_groups; i++) {
assert(weight_zeros[i] == 0);
weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
}

// Prepack weights before packing
// TODO SIMDify this
auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 0);
uint8_t wzp = 8;
for (size_t i = 0; i < n * k; i += 2) {
const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
}

// Parameters for packing
rhs_packing::qparams_t qparams{
.lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};

auto ukernel = get_ukernel();
auto rhs_pack = get_rhs_packing();

rhs_pack.run_rhs_pack(
/*groups=*/1,
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
group_size,
/*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
/*rhs_stride=*/roundup(k, 2) / 2,
/*bias=*/nullptr, // TODO fix APIs to move bias here
/*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
/*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
/*rhs_packed=*/weight_data,
/*extra_bytes=*/0,
/*qparams=*/&qparams);
weight_qvals,
weight_scales,
weight_zeros);
}

void kernel(
// Outputs
float32_t* output,
// Inputs
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
Contributor:
Are these the packed activation and packed weight buffers? If so, maybe worth naming them as such.

(A call-sequence sketch showing both packed buffers follows kernel() below.)

// Not applied if nullptr
const float* bias,
// zeros if has_clamp = false
float clamp_min,
float clamp_max) {
assert(output_m_stride == n);
if (clamp_min == clamp_max && clamp_min == 0) {
clamp_min = std::numeric_limits<float_t>::lowest();
clamp_max = std::numeric_limits<float_t>::max();
}
auto ukernel = get_ukernel();
ukernel.run_matmul(
(void) bias; // unused - needs API fixing
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_col=*/sizeof(float),
/*dst_stride_row=*/ n * sizeof(float),
/*dst_stride_col=*/ sizeof(float),
clamp_min,
clamp_max);
}
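On the question above: yes, both pointers refer to buffers produced by the packing routines earlier in this file. A minimal sketch of the intended call sequence (illustrative only; the wrapper function and buffer names are assumptions, not part of the diff):

#include <cstdint>
#include <vector>

// Assumes the neon_dotprod_1x4x32 functions from this diff are in scope,
// e.g. via:
// using namespace torchao::kernels::cpu::aarch64::kleidi::
//     kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
void linear_sketch(
    float* output, int m, int n, int k, int group_size,
    const float* activations, const int8_t* weight_qvals,
    const float* weight_scales, const int8_t* weight_zeros) {
  // Quantize + reorder the activations into the kernel's packed layout.
  std::vector<char> packed_activations(activation_data_size(m, k, group_size));
  prepare_activation_data(packed_activations.data(), m, k, group_size,
                          activations);

  // Nibble-pack the weights, with bf16 scales, into the packed layout.
  std::vector<char> packed_weights(weight_data_size(n, k, group_size));
  prepare_weight_data(packed_weights.data(), n, k, group_size, weight_qvals,
                      weight_scales, weight_zeros);

  // Run the matmul on the two packed buffers; clamp_min == clamp_max == 0
  // is treated as "no clamp" by kernel() above.
  kernel(output, /*output_m_stride=*/n, m, n, k, group_size,
         packed_weights.data(), packed_activations.data(),
         /*bias=*/nullptr, /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
}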
@@ -7,12 +7,137 @@

#pragma once

#include <cstdint>
#include <cstddef>
#include <cstring>
#include <cassert>
#include <limits>
#include <vector>

#include <kai/kai_common.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>

namespace torchao::kernels::cpu::aarch64::kleidi {

// Helper functions
// TODO: find a better place for these?

size_t roundup(size_t a, size_t b) {
return ((a + b - 1) / b) * b;
}

uint16_t get_bf16_from_float(float f) {
uint16_t bf16;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
memcpy(&bf16, &f, sizeof(uint16_t));
#else
const void* fp = reinterpret_cast<const void*>(
reinterpret_cast<uintptr_t>(&f) + sizeof(float) - sizeof(uint16_t));
memcpy(&bf16, fp, sizeof(uint16_t));
#endif // __BYTE_ORDER__
return bf16;
}

namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

int activation_data_size(const Ukernel ukernel, int m, int k) {
auto lhs_packing = get_lhs_packing();
return lhs_packing.get_lhs_packed_size(
m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
}

void prepare_activation_data(
const Ukernel ukernel,
void* activation_data,
int m,
int k,
const float* activations) {
auto lhs_pack = get_lhs_packing();

lhs_pack.run_lhs_pack(
m,
k,
ukernel.get_mr(),
ukernel.get_kr(),
ukernel.get_sr(),
/*m_index_start=*/0,
activations,
/*lhs_stride=*/k * sizeof(float),
activation_data);
}

int weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
auto rhs_pack = get_rhs_packing();
return rhs_pack.get_rhs_packed_size(
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
group_size,
kai_datatype::kai_dt_bf16);
Contributor:
Why bf16?

Contributor:
I think the Kleidi kernel keeps scales as bf16 to save space.

Contributor (Author):
Yeah, we asked Kleidi to do this. :p

}
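On the bf16 question above: bf16 is simply the upper 16 bits of an IEEE-754 float32, so each group scale costs 2 bytes instead of 4, at the price of mantissa precision. A minimal sketch of the conversion (illustrative only; it matches the little-endian path of get_bf16_from_float() defined earlier in this header):

#include <cstdint>
#include <cstring>

// Truncating float32 -> bf16: keep the sign bit, the 8 exponent bits, and
// the top 7 mantissa bits; drop the low 16 mantissa bits.
inline uint16_t bf16_truncate_sketch(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}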

void prepare_weight_data(
const Ukernel ukernel,
void* weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros) {
// TODO - remove this constraint and pad when possible
assert(n % 2 == 0);

assert(group_size % 32 == 0);
assert(k % group_size == 0);

// TODO SIMDify this
size_t n_groups = n * k / group_size;
auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
for (size_t i = 0; i < n_groups; i++) {
assert(weight_zeros[i] == 0);
Contributor:
Maybe assert that weight_zeros is a nullptr, or that all of its entries are zero?

Contributor (Author):
Can it be a nullptr?

Contributor:
Use nullptr to mean "no weight zeros", rather than creating a buffer of zeros.

Contributor (Author):
Checking this only if !nullptr, else unused.

(A sketch of the null-guarded check follows prepare_weight_data() below.)

weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
}

// Prepack weights before packing
// TODO SIMDify this
auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 0);
uint8_t wzp = 8;
for (size_t i = 0; i < n * k; i += 2) {
const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
}

// Parameters for packing
rhs_packing::qparams_t qparams{
.lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};

auto rhs_pack = get_rhs_packing();

rhs_pack.run_rhs_pack(
/*groups=*/1,
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
group_size,
/*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
/*rhs_stride=*/roundup(k, 2) / 2,
/*bias=*/nullptr, // TODO fix APIs to move bias here
/*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
/*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
/*rhs_packed=*/weight_data,
/*extra_bytes=*/0,
/*qparams=*/&qparams);
}
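On the weight_zeros discussion above, a minimal sketch of the null-guarded check (illustrative only; the PR's later commit "[experimental][kleidi] Allow weight zeros to be a nullptr" makes the actual change):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Validate zero points only when the caller provides them; a null
// weight_zeros means "no weight zeros" (symmetric quantization).
inline void check_weight_zeros_sketch(const int8_t* weight_zeros,
                                      size_t n_groups) {
  if (weight_zeros == nullptr) return;
  for (size_t i = 0; i < n_groups; i++) {
    assert(weight_zeros[i] == 0); // this kernel supports zero-point == 0 only
  }
}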

} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi