
Kleidi 4b blockwise gemv prototype #997

Merged: 19 commits merged into main on Oct 11, 2024
Conversation

digantdesai (Contributor) commented Oct 2, 2024:

This integrates a couple of Neon dotprod Kleidi kernels with the TorchAO GEMM lower-level interface.

The op-level wiring is not part of this PR.

All tests pass for both kernels at a 1e-4 tolerance :)
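For context, a minimal sketch of the kind of element-wise tolerance check this implies (illustrative helper, not the PR's actual test harness):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative only: compare kernel output against a reference GEMM result
// element-wise with an absolute tolerance of 1e-4.
bool allclose(const std::vector<float>& actual,
              const std::vector<float>& expected,
              float atol = 1e-4f) {
  assert(actual.size() == expected.size());
  for (std::size_t i = 0; i < actual.size(); ++i) {
    if (std::fabs(actual[i] - expected[i]) > atol) {
      return false;
    }
  }
  return true;
}
```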

pytorch-bot commented Oct 2, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/997

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f6e22fb with merge base 7038f8b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 2, 2024
@digantdesai force-pushed the kleidi_prototype_gemv branch 2 times, most recently from 343f7d4 to fc0cd6d on October 8, 2024 01:53
@digantdesai marked this pull request as ready for review on October 8, 2024 03:48
@digantdesai force-pushed the kleidi_prototype_gemv branch from 61694eb to a3a49c6 on October 8, 2024 03:51
FetchContent_MakeAvailable(kleidiai)

# Disabled by default. Force enable if we are on a suitable system.
# TODO: Introduce ISA specific flags for i8mm.
Contributor:

Can you leave it disabled by default until we benchmark it against the existing kernel in torchchat? I want to make sure we don't regress torchchat perf.

Contributor (author):

This doesn't wire it up at the op level, we enable it only for armv8, and we only have dotprod kernels, so this should be OK. Before we add i8mm kernels we have to fix the CMake and also the op-level wiring.
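A minimal sketch of the compile-time gating being described, using the standard Arm ACLE feature macros (the macro name TORCHAO_KLEIDI_DOTPROD_AVAILABLE is hypothetical, not from the PR):

```cpp
// Only consider the Kleidi dotprod path on aarch64 builds that advertise the
// dotprod extension; i8mm kernels would additionally need an
// __ARM_FEATURE_MATMUL_INT8 check plus matching compiler flags in CMake.
#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
#define TORCHAO_KLEIDI_DOTPROD_AVAILABLE 1
#else
#define TORCHAO_KLEIDI_DOTPROD_AVAILABLE 0
#endif
```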

@@ -0,0 +1,124 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor:

This file is very similar to the 1x4x32 one above. Do you think it's possible to reuse some code? Same comment for the next file.

Contributor (author):

yes! I want to lean on you c++ experts 😅

Contributor:

If you want to do this as a follow-up that's also OK, but I do agree that it can probably be structured differently, e.g. get_ukernel can be factored out to take the type of the kernel as an arg.
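A rough sketch of that suggestion, with a stand-in Ukernel struct and a hypothetical KernelVariant enum (the real Kleidi function-pointer table has many more entries):

```cpp
#include <cstddef>

// Stand-in for the Kleidi matmul ukernel function-pointer table used in this
// PR (the real one also carries get_kr, get_sr, run_matmul, and so on).
struct Ukernel {
  std::size_t (*get_m_step)();
  std::size_t (*get_n_step)();
};

// Hypothetical: the "type of the kernel" passed as an argument.
enum class KernelVariant {
  neon_dotprod_1x4x32,
  neon_dotprod_1x8x32,
};

// One get_ukernel() taking the variant as an arg, instead of a near-identical
// copy per header.
inline Ukernel get_ukernel(KernelVariant variant) {
  switch (variant) {
    case KernelVariant::neon_dotprod_1x4x32:
      return Ukernel{/* designated initializers for the 1x4x32 kernel */};
    case KernelVariant::neon_dotprod_1x8x32:
    default:
      return Ukernel{/* designated initializers for the 1x8x32 kernel */};
  }
}
```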

// #ifdef TORCHAO_ENABLE_KLEIDI
// TODO: Wire up the compile definition for TORCHAO_ENABLE_KLEIDI

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
Contributor:

Is this templating needed for Kleidi?

Contributor (author):

Will remove weight_nbit and has_weight_zeros.

has_bias is something we will need. And I'll add new tests for has_clamp :P
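A sketch of what the slimmed-down signature might look like once weight_nbit and has_weight_zeros are removed (names and parameter list are illustrative, not the PR's actual wrapper):

```cpp
#include <limits>

// Illustrative only: with weight_nbit fixed at 4 and no weight zero points,
// only bias and clamp handling remain as compile-time flags.
template <bool has_bias, bool has_clamp>
void kernel_1x8x32(
    float* output, int m, int n, int k, int group_size,
    const void* packed_weight_data, const void* packed_activation_data,
    float clamp_min, float clamp_max) {
  // Without clamping, drive the kernel with the full float range.
  if constexpr (!has_clamp) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }
  // ... call the Kleidi ukernel's run_matmul with the packed buffers and the
  // resolved clamp bounds; for the Kleidi path, bias (if has_bias) is folded
  // in during weight packing rather than applied here ...
  (void)output; (void)m; (void)n; (void)k; (void)group_size;
  (void)packed_weight_data; (void)packed_activation_data;
  (void)clamp_min; (void)clamp_max;
}
```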

#endif // defined(__aarch64__) || defined(__ARM_NEON)

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
Contributor:

protect with TORCHAO_ENABLE_KLEIDI

Contributor (author):

I guess I can drop the op level completely.

Contributor (author):

dropped.

@@ -8,9 +8,11 @@

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
Contributor:

protect with TORCHAO_ENABLE_KLEIDI
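A minimal sketch of the requested guard, assuming TORCHAO_ENABLE_KLEIDI eventually gets wired up as a compile definition (the TODO earlier in the thread notes it is not yet):

```cpp
#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#if defined(TORCHAO_ENABLE_KLEIDI)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
#endif // TORCHAO_ENABLE_KLEIDI
#endif // defined(__aarch64__) || defined(__ARM_NEON)
```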

@digantdesai force-pushed the kleidi_prototype_gemv branch 2 times, most recently from 7a766df to b1e6f8e on October 10, 2024 15:19
kimishpatel (Contributor) left a comment:

Mine are mostly nits at this point

torchao/experimental/build_torchao_ops.sh (resolved)
# KleidiAI is an open-source library that provides optimized
# performance-critical routines, also known as micro-kernels, for artificial
# intelligence (AI) workloads tailored for Arm® CPUs.
FetchContent_Declare(kleidiai
Contributor:

Why add this as a build-time dependency instead of a third-party lib? Wait, I guess it's on GitLab?

Contributor (author):

Do you mean as opposed to a git submodule? Just to keep it simple for now.


#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
Contributor:

What is the C++ standard requirement for using a namespace like this? Just confirm that it is at least C++17.

Contributor (author):

CMake dictates we can assume C++17.
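For reference, the nested namespace definition used here is indeed a C++17 feature; under C++14 it would need the expanded form:

```cpp
// C++17 nested namespace definition (as written in this PR):
namespace torchao::kernels::cpu::aarch64::kleidi {
} // namespace torchao::kernels::cpu::aarch64::kleidi

// Pre-C++17 equivalent:
namespace torchao { namespace kernels { namespace cpu {
namespace aarch64 { namespace kleidi {
}}}}} // namespace torchao::kernels::cpu::aarch64::kleidi
```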

namespace neon_dotprod_1x8x32 {
const Ukernel get_ukernel() {
return Ukernel{
.get_m_step =
Contributor:

Also what are m/n step?

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x8x32 {
const Ukernel get_ukernel() {
Contributor:

For future: I presume you will have to parameterize this for different kernels?

Also, would it make sense to structure this in a way that this function moves to Kleidi?

Contributor (author):

Need to think some more to support (1) AoT/runtime weight packing and (2) per-CPU uArch-based ukernel selection. That logic would dictate how this interface looks. So I did something minimal here for the "prototype", but I agree we can improve.
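A rough sketch of the per-uArch selection half of that, with a stand-in Ukernel struct; the feature check here is a placeholder (real code would detect dotprod at runtime, e.g. via getauxval on Linux, rather than rely on compile-time macros alone):

```cpp
#include <cstddef>

// Stand-in for the Kleidi ukernel function-pointer table (see earlier sketch).
struct Ukernel {
  std::size_t (*get_m_step)();
  std::size_t (*get_n_step)();
};

// Placeholder feature check: true only if this translation unit was compiled
// with dotprod enabled; a real implementation would query the OS at runtime.
inline bool cpu_has_dotprod() {
#if defined(__ARM_FEATURE_DOTPROD)
  return true;
#else
  return false;
#endif
}

// Hypothetical selector: prefer the Neon dotprod kernel when available,
// otherwise fall back to a non-Kleidi kernel. An i8mm branch could be added
// later once those kernels land.
inline Ukernel select_ukernel() {
  if (cpu_has_dotprod()) {
    return Ukernel{/* neon_dotprod_1x8x32 entries */};
  }
  return Ukernel{/* non-Kleidi fallback entries */};
}
```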

int k,
int group_size,
const void* weight_data,
const void* activation_data,
Contributor:

Is this packed activation and packed weight? If so, maybe worth naming them as such.

clamp_max);
}

size_t get_alignement() {
Contributor:

unused

Contributor:

It's part of the high-level op interface. FYI, @digantdesai, a landing bootcamper diff [D63873383] renamed things to preferred alignment to address a BE/EE backlog task. So make sure you rebase and retest before landing.

Contributor (author):

Oh, this is for the op level, which can come in later diffs.

ukernel.get_kr(),
ukernel.get_sr(),
group_size,
kai_datatype::kai_dt_bf16);
Contributor:

why bf16?

Contributor:

I think the Kleidi kernel keeps scales as bf16 to save space.

Contributor (author):

yeah we asked Kleidi to do this :p
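For context, bf16 keeps only the top 16 bits of an fp32 value, so storing per-group scales as bf16 halves their footprint at a small precision cost. A minimal round-to-nearest-even conversion sketch (not Kleidi's actual packing code; NaN handling omitted):

```cpp
#include <cstdint>
#include <cstring>

// Convert an fp32 scale to bf16 by rounding to the upper 16 bits of the
// IEEE-754 bit pattern (round-to-nearest-even).
inline uint16_t f32_to_bf16(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>(bits >> 16);
}
```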

Comment on lines +17 to +19
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
Contributor:

so this should be behind TORCHAO_ENABLE_KLEIDI?

size_t n_groups = n * k / group_size;
auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
for (size_t i = 0; i < n_groups; i++) {
assert(weight_zeros[i] == 0);
Contributor:

Maybe assert weight_zeros is a nullptr or all of its entries are zero?

Contributor (author):

can it be a nullptr?

Contributor:

Use nullptr to mean "no weight zeros", rather than creating a buffer of zeros.

Contributor (author):

Checking this only if it's not nullptr; otherwise it's unused.
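A sketch of the agreed behavior (the name and signature are illustrative): a null weight_zeros means "no zero points", and a non-null one must be all zeros for the symmetric Kleidi path:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustrative only: the Kleidi 4-bit kernels assume symmetric quantization,
// so zero points are either absent (nullptr) or all zero.
inline void check_weight_zeros(const int8_t* weight_zeros, std::size_t n_groups) {
  if (weight_zeros == nullptr) {
    return;  // nullptr means "no weight zeros"
  }
  for (std::size_t i = 0; i < n_groups; ++i) {
    assert(weight_zeros[i] == 0);
  }
}
```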

const float* bias,
float clamp_min,
float clamp_max) {
(void)bias; // unused - needs API fixing
Contributor:

That or it could be added in this wrapper after the ukernel.run_matmul call.

Contributor (author):

Not sure I follow, can you elaborate? Kleidi wants the bias in weight packing, not here.
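For reference, the alternative the reviewer describes would look roughly like this, applied to the output tile after the ukernel.run_matmul call (placeholder names; the PR instead passes bias to Kleidi's weight packing):

```cpp
#include <cstddef>

// Hypothetical post-matmul bias add: output is an m x n row-major tile and
// bias holds one value per output column (n of them).
inline void add_bias(float* output, const float* bias, int m, int n) {
  if (bias == nullptr) {
    return;  // no bias requested
  }
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      output[row * n + col] += bias[col];
    }
  }
}
```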

@digantdesai force-pushed the kleidi_prototype_gemv branch from b42fd69 to e68a9e2 on October 10, 2024 18:53
@facebook-github-bot: @digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@digantdesai force-pushed the kleidi_prototype_gemv branch from e68a9e2 to f9a68f9 on October 10, 2024 20:30
@facebook-github-bot: @digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@digantdesai force-pushed the kleidi_prototype_gemv branch from f9a68f9 to f6e22fb on October 11, 2024 00:06
@facebook-github-bot: @digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot merged commit db72dd1 into main on Oct 11, 2024 (18 of 19 checks passed)
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
Differential Revision: D64194844

Pull Request resolved: #997
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
As I was looking through the documentation, in `/docs/Models.md` I noticed one relative link, `docs/GGUF.md`, has a typo; it should be `GGUF.md`, so I changed it.
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: none
4 participants