Kleidi 4b blockwise gemv prototype #997
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/997
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f6e22fb with merge base 7038f8b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 343f7d4 to fc0cd6d
Force-pushed 61694eb to a3a49c6
FetchContent_MakeAvailable(kleidiai)

# Disabled by default. Force enable if we are on a suitable system.
# TODO: Introduce ISA specific flags for i8mm.
Can you leave it disabled by default until we benchmark it against the existing kernel in torchchat? I want to make sure we don't regress torchchat perf.
This doesn't wire it up at the op level; we enable it only for armv8, and we only have dotprod kernels, so this should be OK. Before we add i8mm kernels we have to fix the CMake and also the op-level wiring.
...kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -0,0 +1,124 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
This file is very similar to the 1x4x32 one above. Do you think it's possible to reuse some code? Same comment for the next file.
Yes! I want to lean on you C++ experts 😅
If you want to do this as a follow-up that's also OK, but I do agree that it can probably be structured differently, e.g. get_ukernel can be factored out to take the type of the kernel as an argument.
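A rough sketch of that idea, not the PR's code: a single get_ukernel that takes the kernel variant as an argument and delegates to the per-variant namespaces this PR adds. The KernelVariant enum and the wrapper name are hypothetical; Ukernel is assumed to come from the shared header quoted elsewhere in this thread.

```cpp
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

// Hypothetical: which KleidiAI wrapper to use.
enum class KernelVariant {
  neon_dotprod_1x4x32,
  neon_dotprod_1x8x32,
};

// One entry point; each case delegates to the per-variant namespace that
// wraps the corresponding KleidiAI functions in this PR.
inline Ukernel get_ukernel(KernelVariant variant) {
  switch (variant) {
    case KernelVariant::neon_dotprod_1x4x32:
      return neon_dotprod_1x4x32::get_ukernel();
    case KernelVariant::neon_dotprod_1x8x32:
    default:
      return neon_dotprod_1x8x32::get_ukernel();
  }
}

} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
```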
// #ifdef TORCHAO_ENABLE_KLEIDI
// TODO: Wire up the compile definition for TORCHAO_ENABLE_KLEIDI

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
Is this templating needed for Kleidi?
Will remove weight_nbit and has_weight_zeros.
has_bias is something we will need. And adding new tests for has_clamp :P
#endif // defined(__aarch64__) || defined(__ARM_NEON)

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
protect with TORCHAO_ENABLE_KLEIDI
I guess I can drop the op level completely.
dropped.
@@ -8,9 +8,11 @@

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
protect with TORCHAO_ENABLE_KLEIDI
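A minimal sketch of the suggested guard, assuming TORCHAO_ENABLE_KLEIDI is the compile definition mentioned in the TODO above and is set by the build when the KleidiAI path is enabled:

```cpp
#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#if defined(TORCHAO_ENABLE_KLEIDI)
// Only pull in the KleidiAI-backed packing header when the backend is enabled.
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
#endif // TORCHAO_ENABLE_KLEIDI
#endif // defined(__aarch64__) || defined(__ARM_NEON)
```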
Force-pushed 7a766df to b1e6f8e
Mine are mostly nits at this point
# KleidiAI is an open-source library that provides optimized
# performance-critical routines, also known as micro-kernels, for artificial
# intelligence (AI) workloads tailored for Arm® CPUs.
FetchContent_Declare(kleidiai
Why add this as a build-time dependency instead of a 3p-lib? Wait, I guess GitLab?
Do you mean as opposed to a git submodule? Just to keep it simple for now.
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
What is the C++ compliance on using a namespace like this? Just confirm that it is at least C++17.
CMake dictates we can assume C++17.
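For reference, the nested namespace definition used in the diff is indeed a C++17 feature; before C++17 each level has to be opened explicitly. A minimal illustration:

```cpp
// C++17 and later: nested namespace definition in one line.
namespace torchao::kernels::cpu::aarch64::kleidi {
} // namespace torchao::kernels::cpu::aarch64::kleidi

// Pre-C++17 equivalent.
namespace torchao { namespace kernels { namespace cpu {
namespace aarch64 { namespace kleidi {
}}}}} // namespace torchao::kernels::cpu::aarch64::kleidi
```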
namespace neon_dotprod_1x8x32 {
const Ukernel get_ukernel() {
  return Ukernel{
      .get_m_step =
Also what are m/n step?
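For context, my understanding of the KleidiAI micro-kernel interface (not stated in this thread) is that m_step/n_step are the number of output rows/columns the micro-kernel produces per invocation, so callers tile the M×N output in those increments. An illustrative sketch, with UkernelT assumed to expose get_m_step()/get_n_step() like the Ukernel struct above:

```cpp
#include <cstddef>

// Illustrative tiling only, not the PR's code.
template <typename UkernelT>
void for_each_output_tile(const UkernelT& ukernel, size_t m, size_t n) {
  const size_t m_step = ukernel.get_m_step(); // output rows per micro-kernel call
  const size_t n_step = ukernel.get_n_step(); // output columns per micro-kernel call
  for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
    for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
      // run_matmul would be invoked here on the tile starting at (m_idx, n_idx).
    }
  }
}
```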
namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x8x32 {
const Ukernel get_ukernel() {
For the future: I presume you will have to parameterize this for different kernels?
Also, would it make sense to structure this in a way that this function moves to Kleidi?
Need to think some more to support (1) AoT/runtime weight packing and (2) per-CPU uArch-based ukernel selection. That logic would dictate what this interface looks like, so I did something minimal here for the "prototype", but I agree we can improve.
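As a follow-up sketch only (not this PR): one possible shape for the uArch-based selection mentioned above. The function name is made up, the caller is assumed to pass the result of a CPU-feature probe, and only the dotprod path exists today.

```cpp
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

// Hypothetical selection shape: the caller passes the result of a uArch probe.
inline Ukernel select_ukernel(bool cpu_has_i8mm) {
  if (cpu_has_i8mm) {
    // Future: return an i8mm ukernel once those kernels and the CMake/ISA
    // flags mentioned in the TODO above exist.
  }
  return neon_dotprod_1x8x32::get_ukernel(); // dotprod is all this PR provides
}

} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
```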
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,
Is this packed activation and packed weight? If so, maybe worth naming it as such.
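A sketch of the suggested rename only; the surrounding parameters are assumed from the diff excerpts and this is not the PR's exact signature.

```cpp
// Illustrative declaration: make the packing requirement visible at the call site.
void kernel_1x8x32_neon_dotprod(
    float* output,                       // assumed output buffer
    int m,
    int n,
    int k,
    int group_size,
    const void* packed_weight_data,      // was: weight_data
    const void* packed_activation_data,  // was: activation_data
    const float* bias,
    float clamp_min,
    float clamp_max);
```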
      clamp_max);
}

size_t get_alignement() {
unused
It's part of the high-level op interface. FYI @digantdesai, a landing bootcamper diff (D63873383) renamed things to preferred alignment to address a BE/EE backlog task, so make sure you rebase and retest before landing.
Oh, this is for the op level, which can come in later diffs.
      ukernel.get_kr(),
      ukernel.get_sr(),
      group_size,
      kai_datatype::kai_dt_bf16);
why bf16?
I think the Kleidi kernel keeps scales as bf16 to save space.
Yeah, we asked Kleidi to do this :p
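For illustration, how an fp32 scale can be stored as bf16 (the upper 16 bits of the IEEE-754 float), which is presumably what the packing path does when it passes kai_dt_bf16. The helper name is an assumption; truncation is shown, rounding is a variant.

```cpp
#include <cstdint>
#include <cstring>

// Minimal sketch: bf16 keeps the sign, exponent, and top 7 mantissa bits.
inline uint16_t fp32_to_bf16_truncate(float scale) {
  uint32_t bits;
  std::memcpy(&bits, &scale, sizeof(bits));  // reinterpret the float's bits
  return static_cast<uint16_t>(bits >> 16);  // drop the low 16 mantissa bits
}
```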
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
so this should be behind TORCHAO_ENABLE_KLEIDI?
size_t n_groups = n * k / group_size;
auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
for (size_t i = 0; i < n_groups; i++) {
  assert(weight_zeros[i] == 0);
Maybe assert weight_zeros is a nullptr or all of its entries are zero?
can it be a nullptr?
Use nullptr to mean "no weight zeros", rather than creating an array of zeros.
Checking this only if it's not nullptr; otherwise it's unused.
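A minimal sketch of the agreed-upon check, assuming weight_zeros may be nullptr to mean "no zero points"; the helper name and the int8_t element type are assumptions, not the PR's code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Only validate the zero points when they are provided.
inline void check_weight_zeros_unused(const int8_t* weight_zeros, size_t n_groups) {
  if (weight_zeros == nullptr) {
    return;  // nothing to validate; the Kleidi path assumes zero offsets
  }
  for (size_t i = 0; i < n_groups; i++) {
    assert(weight_zeros[i] == 0);  // the Kleidi packing path has no zero-point support
  }
}
```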
    const float* bias,
    float clamp_min,
    float clamp_max) {
  (void)bias; // unused - needs API fixing
That or it could be added in this wrapper after the ukernel.run_matmul call.
Not sure I follow, can you elaborate? Kleidi wants the bias in weight packing, not here.
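For clarity, a sketch of the reviewer's alternative only (the reply notes Kleidi expects the bias during weight packing instead): apply the bias to the fp32 output after the micro-kernel has run. The function name and layout assumptions are illustrative.

```cpp
#include <cstddef>

// Add one bias value per output column after run_matmul has produced the output.
inline void add_bias_rowwise(
    float* output, size_t m, size_t n, size_t out_stride, const float* bias) {
  if (bias == nullptr) {
    return;  // no bias to apply
  }
  for (size_t row = 0; row < m; row++) {
    for (size_t col = 0; col < n; col++) {
      output[row * out_stride + col] += bias[col];
    }
  }
}
```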
Force-pushed b42fd69 to e68a9e2
@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed e68a9e2 to f9a68f9
@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed f9a68f9 to f6e22fb
@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Differential Revision: D64194844
Pull Request resolved: #997
While looking through the documentation, I noticed that in `/docs/Models.md` one relative link, `docs/GGUF.md`, has a typo; it should be `GGUF.md`, so I changed it.
This integrates a couple of NEON dotprod Kleidi kernels with the TorchAO GEMM lower-level interface.
The op-level wiring is not part of this PR.
All tests pass for both kernels with a 1e-4 tolerance :)
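For context, a minimal sketch of the kind of elementwise tolerance check that "pass with 1e-4" implies; the helper name and signature are assumptions, not the PR's test code.

```cpp
#include <cmath>
#include <cstddef>

// Compare kernel output against a reference within an absolute tolerance.
inline bool allclose(const float* actual, const float* expected, size_t count,
                     float atol = 1e-4f) {
  for (size_t i = 0; i < count; i++) {
    if (std::fabs(actual[i] - expected[i]) > atol) {
      return false;  // mismatch beyond the tolerance used by the tests
    }
  }
  return true;
}
```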