Kleidi 4b blockwise GEMV prototype #997
Changes from 1 commit
@@ -6,26 +6,17 @@
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x4x32 {

const Ukernel get_ukernel() {
  return Ukernel{
      .get_m_step =
          kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_n_step =
@@ -50,158 +41,79 @@ ukernel get_ukernel() {
          kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
}

int activation_data_size(int m, int k, int group_size) {
  (void)group_size; // unused
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
      get_ukernel(), m, k);
}
void prepare_activation_data(
    void* activation_data,
    // Inputs
    int m,
    int k,
    // Ignored if has_weight_zeros = false
    int group_size,
    const float* activations) {
  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
      get_ukernel(),
      activation_data,
      m,
      k,
      activations);
}
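A minimal caller sketch for the activation side (my illustration, not from this PR; sizes and buffer names are hypothetical, and the enclosing namespaces are assumed to be in scope):

// Size the packed buffer, then fill it from a row-major f32 activation matrix.
int m = 1, k = 4096, group_size = 32;  // GEMV case: m == 1
std::vector<float> activations(m * k, 0.0f);
std::vector<char> packed_lhs(activation_data_size(m, k, group_size));
prepare_activation_data(packed_lhs.data(), m, k, group_size, activations.data());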
int weight_data_size(int n, int k, int group_size) {
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
      get_ukernel(), n, k, group_size);
}
void prepare_weight_data(
    void* weight_data,
    // Inputs
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros) {
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
      weight_data,
      n,
      k,
      group_size,
      weight_qvals,
      weight_scales,
      weight_zeros);
}
void kernel(
    // Outputs
    float32_t* output,
    // Inputs
    int output_m_stride,
    int m,
    int n,
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,

Review comment: are these the packed activations and packed weights? If so, maybe worth naming them as such.
    // Not applied if nullptr
    const float* bias,
    // zeros if has_clamp = false
    float clamp_min,
    float clamp_max) {
  (void)bias; // unused - needs API fixing
  assert(output_m_stride == n);
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }

  auto ukernel = get_ukernel();
  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
      /*dst_stride_row=*/n * sizeof(float),
      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}
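To see how the pieces compose, here is a minimal end-to-end sketch (my illustration; the sizes, fill values, and using-directive are hypothetical, and a real caller would supply actual quantized weights):

using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p;

int m = 1, n = 256, k = 512, group_size = 32;  // GEMV: m == 1

// Quantized weights supplied by the caller: signed 4-bit values in [-8, 7],
// one scale per group of 32 along k, zero points all zero.
std::vector<int8_t> weight_qvals(n * k, 1);
std::vector<float> weight_scales(n * k / group_size, 0.5f);
std::vector<int8_t> weight_zeros(n * k / group_size, 0);

// Pack weights once (e.g., at model load time).
std::vector<char> packed_rhs(neon_dotprod_1x4x32::weight_data_size(n, k, group_size));
neon_dotprod_1x4x32::prepare_weight_data(
    packed_rhs.data(), n, k, group_size,
    weight_qvals.data(), weight_scales.data(), weight_zeros.data());

// Pack activations and run the kernel per inference call.
std::vector<float> activations(m * k, 1.0f);
std::vector<char> packed_lhs(neon_dotprod_1x4x32::activation_data_size(m, k, group_size));
neon_dotprod_1x4x32::prepare_activation_data(
    packed_lhs.data(), m, k, group_size, activations.data());

std::vector<float> output(m * n);
neon_dotprod_1x4x32::kernel(
    output.data(), /*output_m_stride=*/n, m, n, k, group_size,
    packed_rhs.data(), packed_lhs.data(), /*bias=*/nullptr,
    /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);  // 0/0 disables clamping per the code above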
@@ -7,12 +7,137 @@

#pragma once

#include <cstdint>
#include <cstddef>
#include <cstring>
#include <cassert>
#include <limits>
#include <vector>

#include <kai/kai_common.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>

namespace torchao::kernels::cpu::aarch64::kleidi {

// Helper functions
// TODO: find a better place for these?

size_t roundup(size_t a, size_t b) {
  return ((a + b - 1) / b) * b;
}
uint16_t get_bf16_from_float(float f) {
  uint16_t bf16;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  memcpy(&bf16, &f, sizeof(uint16_t));
#else
  const void* fp = reinterpret_cast<const void*>(
      reinterpret_cast<uintptr_t>(&f) + sizeof(float) - sizeof(uint16_t));
  memcpy(&bf16, fp, sizeof(uint16_t));
#endif // __BYTE_ORDER__
  return bf16;
}
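The truncation above keeps the high 16 bits of the IEEE-754 float32 encoding (sign, 8-bit exponent, top 7 mantissa bits), with no rounding. A small self-check sketch (mine), assuming the little-endian path:

// 1.0f encodes as 0x3F800000, so its bf16 truncation is 0x3F80.
assert(get_bf16_from_float(1.0f) == 0x3F80);

// Widening back by zero-filling the low 16 bits recovers an approximation.
uint32_t widened = static_cast<uint32_t>(get_bf16_from_float(3.14159f)) << 16;
float approx;
memcpy(&approx, &widened, sizeof(float));  // approx == 3.140625f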
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

int activation_data_size(const Ukernel ukernel, int m, int k) {
  auto lhs_packing = get_lhs_packing();
  return lhs_packing.get_lhs_packed_size(
      m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
}
void prepare_activation_data(
    const Ukernel ukernel,
    void* activation_data,
    int m,
    int k,
    const float* activations) {
  auto lhs_pack = get_lhs_packing();

  lhs_pack.run_lhs_pack(
      m,
      k,
      ukernel.get_mr(),
      ukernel.get_kr(),
      ukernel.get_sr(),
      /*m_index_start=*/0,
      activations,
      /*lhs_stride=*/k * sizeof(float),
      activation_data);
}
int weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
  auto rhs_pack = get_rhs_packing();
  return rhs_pack.get_rhs_packed_size(
      n,
      k,
      ukernel.get_nr(),
      ukernel.get_kr(),
      ukernel.get_sr(),
      group_size,
      kai_datatype::kai_dt_bf16);
}

Review comment: why bf16?
Reply: I think the Kleidi kernel keeps scales as bf16 to save space.
Reply: yeah, we asked Kleidi to do this :p
void prepare_weight_data(
    const Ukernel ukernel,
    void* weight_data,
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros) {
  // TODO - remove this constraint and pad when possible
  assert(n % 2 == 0);

  assert(group_size % 32 == 0);
  assert(k % group_size == 0);

  // Convert scales to bf16
  // TODO SIMDify this
  size_t n_groups = n * k / group_size;
  auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
  for (size_t i = 0; i < n_groups; i++) {
    assert(weight_zeros[i] == 0);
    weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
  }

Review comment: maybe assert that weight_zeros is a nullptr, or that all of its entries are zero?
Reply: can it be a nullptr?
Reply: use nullptr to mean "no weight zeros", rather than creating a buffer of zeros.
Reply: checking this only if !nullptr, else unused.
  // Prepack weights before packing
  // TODO SIMDify this
  auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 0);
  uint8_t wzp = 8;
  for (size_t i = 0; i < n * k; i += 2) {
    const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
    const uint8_t high = static_cast<uint8_t>(weight_qvals[i + 1] + wzp);
    packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
  }
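As a worked example of this scheme (my illustration): each signed 4-bit value in [-8, 7] is shifted by the zero point 8 into an unsigned nibble, and consecutive pairs share a byte, with the second value in the high nibble.

// q[0] = -3  ->  low  = -3 + 8 =  5 (0x5)
// q[1] =  5  ->  high =  5 + 8 = 13 (0xD)
// packed byte = (0xD << 4) | 0x5 = 0xD5
const int8_t q[2] = {-3, 5};
const uint8_t packed =
    (static_cast<uint8_t>(q[1] + 8) << 4) | (static_cast<uint8_t>(q[0] + 8) & 0xF);
assert(packed == 0xD5);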
  // Parameters for packing
  rhs_packing::qparams_t qparams{
      .lhs_zero_point = 1,
      .rhs_zero_point = wzp,
      .scale_dt = kai_datatype::kai_dt_bf16};

  auto rhs_pack = get_rhs_packing();

  rhs_pack.run_rhs_pack(
      /*groups=*/1,
      n,
      k,
      ukernel.get_nr(),
      ukernel.get_kr(),
      ukernel.get_sr(),
      group_size,
      /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
      /*rhs_stride=*/roundup(k, 2) / 2,
      /*bias=*/nullptr, // TODO fix APIs to move bias here
      /*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
      /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size),
      /*rhs_packed=*/weight_data,
      /*extra_bytes=*/0,
      /*qparams=*/&qparams);
}

} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
Review comment: what is the C++ compliance on using namespaces like this? Just confirm that it is at least C++17.
Reply: CMake dictates we can assume C++17.
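For reference, the nested namespace definition used in this header is indeed a C++17 feature, so C++17 is the floor. My illustration:

// C++17 nested namespace definition, as used in this header:
namespace torchao::kernels::cpu::aarch64::kleidi {
} // namespace torchao::kernels::cpu::aarch64::kleidi

// Pre-C++17, the nesting had to be spelled out:
namespace torchao { namespace kernels { namespace cpu {
namespace aarch64 { namespace kleidi {
}}}}} // torchao::kernels::cpu::aarch64::kleidi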