From bb4ab4f3f4abe2043ad794db5cff25c5e9b7c928 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 21:43:29 -0500
Subject: [PATCH] Kleidi 4b blockwise gemv prototype

Differential Revision: D64194844

Pull Request resolved: https://github.com/pytorch/ao/pull/997
---
 torchao/experimental/build_torchao_ops.sh     |   9 +-
 .../kernels/cpu/aarch64/CMakeLists.txt        |  23 ++
 .../benchmarks/build_and_run_benchmarks.sh    |  15 +-
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 123 +++++++++
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 124 +++++++++
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h   | 151 +++++++++++
 .../kernels/cpu/aarch64/kleidi/pack.h         | 118 +++++++++
 .../kernels/cpu/aarch64/tests/CMakeLists.txt  |  13 +
 .../cpu/aarch64/tests/build_and_run_tests.sh  |   7 +-
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 238 +++++++++++++++++-
 .../kernels/cpu/aarch64/tests/test_utils.h    |  31 ++-
 11 files changed, 842 insertions(+), 10 deletions(-)
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h

diff --git a/torchao/experimental/build_torchao_ops.sh b/torchao/experimental/build_torchao_ops.sh
index 2cb7201588..1f13f36c77 100644
--- a/torchao/experimental/build_torchao_ops.sh
+++ b/torchao/experimental/build_torchao_ops.sh
@@ -1,16 +1,21 @@
-#!/bin/bash
+#!/bin/bash -eu
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+if [[ $# -ne 1 ]]; then
+  echo "Usage: $0 ";
+  exit 1;
+fi
+TARGET="${1}"
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
 export CMAKE_OUT=/tmp/cmake-out/torchao
 cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
     -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
-    -DTORCHAO_OP_TARGET="$1" \
+    -DTORCHAO_OP_TARGET="${TARGET}" \
     -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
     -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
     -S . \
diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
index 4f36945f8a..6541d7fdac 100644
--- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -4,6 +4,22 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+include(FetchContent)
+
+# KleidiAI is an open-source library that provides optimized
+# performance-critical routines, also known as micro-kernels, for artificial
+# intelligence (AI) workloads tailored for Arm® CPUs.
+FetchContent_Declare(kleidiai
+        GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
+        GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
+
+FetchContent_MakeAvailable(kleidiai)
+
+# Disabled by default. Force enable if we are on a suitable system.
+# TODO: Introduce ISA specific flags for i8mm.
+CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library" + OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON) + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") add_library( torchao_kernels_aarch64 @@ -12,6 +28,13 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp ) + if (BUILD_KLEIDI) + # Temporarily exposing this to the parent scope until we wire + # this up properly from the top level + set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE) + message(STATUS "Building with Kleidi") + target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) + endif() endif() install( diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh index 08f8358365..e7fa9402e2 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh @@ -1,15 +1,24 @@ -#!/bin/bash +#!/bin/bash -eu # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} +set -eu +if [[ $# -ne 1 ]]; then + echo "Usage: $0 "; + exit 1; +fi + +BENCHMARK_TYPE="${1}" SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) + export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks + +# Build cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \ -B ${CMAKE_OUT} @@ -17,7 +26,7 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ cmake --build ${CMAKE_OUT} # Run -case "$1" in +case "${BENCHMARK_TYPE}" in quantization) ${CMAKE_OUT}/benchmark_quantization; ;; bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;; linear) ${CMAKE_OUT}/benchmark_linear; ;; diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h new file mode 100644 index 0000000000..ded42df9db --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -0,0 +1,123 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
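The headers added below wrap each KleidiAI micro-kernel variant behind a small pack-then-run interface (activation_data_size, prepare_activation_data, weight_data_size, prepare_weight_data, kernel). A minimal usage sketch mirroring how the new tests call it; the include path, the char buffer type, and the linear_4bit_example wrapper are illustrative assumptions, and buffer alignment via get_preferred_alignement() is skipped for brevity:

#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

namespace kai_1x4x32 = torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;

// Hypothetical end-to-end helper: activations are m x k fp32 (row major),
// weight_qvals are n x k signed int4 values stored one per int8, and
// weight_scales holds one fp32 scale per group of group_size values.
void linear_4bit_example(
    int m, int n, int k, int group_size,
    const float* activations,
    const int8_t* weight_qvals,
    const float* weight_scales,
    float* output) {
  // Quantize and pack the activations (LHS) into the kernel's layout.
  std::vector<char> packed_lhs(
      kai_1x4x32::activation_data_size(m, k, group_size));
  kai_1x4x32::prepare_activation_data(
      packed_lhs.data(), m, k, group_size, activations);

  // Pack the int4 weights and their bf16-converted group scales (RHS).
  std::vector<char> packed_rhs(kai_1x4x32::weight_data_size(n, k, group_size));
  kai_1x4x32::prepare_weight_data(
      packed_rhs.data(), n, k, group_size,
      weight_qvals, weight_scales, /*weight_zeros=*/nullptr);

  // Run the micro-kernel; clamp_min == clamp_max == 0 is treated as "no clamp".
  kai_1x4x32::kernel(
      output, /*output_m_stride=*/n, m, n, k, group_size,
      packed_rhs.data(), packed_lhs.data(),
      /*bias=*/nullptr, /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
}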
+ +#pragma once + +#include + +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { + +namespace neon_dotprod_1x4x32 { +const Ukernel get_ukernel() { + return Ukernel{ + .get_m_step = + kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_n_step = + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_mr = + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_nr = + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_kr = + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_sr = + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_dst_offset = + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .get_dst_size = + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + .run_matmul = + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}; +} + +size_t activation_data_size(int m, int k, int group_size) { + (void)group_size; // unused + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( + get_ukernel(), m, k); +} + +void prepare_activation_data( + void* activation_data, + int m, + int k, + int group_size, + const float* activations) { + (void)group_size; // unused + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( + get_ukernel(), activation_data, m, k, activations); +} + +size_t weight_data_size(int n, int k, int group_size) { + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( + get_ukernel(), n, k, group_size); +} + +void prepare_weight_data( + void* weight_data, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( + get_ukernel(), + weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros); +} + +void kernel( + float32_t* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + const float* bias, + float clamp_min, + float clamp_max) { + (void)bias; // TODO(T203756650) - unused - needs API fixing + assert(output_m_stride == n); + if (clamp_min == 0 && clamp_max == 0) { + clamp_min = std::numeric_limits::lowest(); + clamp_max = std::numeric_limits::max(); + } + + auto ukernel = get_ukernel(); + ukernel.run_matmul( + m, + n, + k, + group_size, + activation_data, + weight_data, + output, + /*dst_stride_row=*/n * sizeof(float), + /*dst_stride_col=*/sizeof(float), + clamp_min, + clamp_max); +} + +size_t get_preferred_alignement() { + return 16; +} +} // namespace neon_dotprod_1x4x32 +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h new file mode 100644 index 0000000000..116f25bd59 --- /dev/null +++ 
b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -0,0 +1,124 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { +namespace neon_dotprod_1x8x32 { +const Ukernel get_ukernel() { + return Ukernel{ + .get_m_step = + kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_n_step = + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_mr = + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_nr = + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_kr = + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_sr = + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_dst_offset = + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_dst_size = + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .run_matmul = + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}; +} + +size_t activation_data_size(int m, int k, int group_size) { + (void) group_size; // unused + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k); +} + +void prepare_activation_data( + void* activation_data, + int m, + int k, + int group_size, + const float* activations) { + (void) group_size; // unused + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( + get_ukernel(), + activation_data, + m, + k, + activations); +} + +size_t weight_data_size(int n, int k, int group_size) { + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size); +} + +void prepare_weight_data( + void* weight_data, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( + get_ukernel(), + weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros); +} + +void kernel( + float32_t* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + const float* bias, + float clamp_min, + float clamp_max) { + (void) bias; // TODO(T203756650) - unused - needs API fixing + assert(output_m_stride == n); + if (clamp_min == 0 && clamp_max == 0) { + clamp_min = std::numeric_limits::lowest(); + clamp_max = std::numeric_limits::max(); + } + + auto ukernel = get_ukernel(); + ukernel.run_matmul( + m, + n, + k, + group_size, + activation_data, + weight_data, + output, + /*dst_stride_row=*/ n * sizeof(float), + /*dst_stride_col=*/ sizeof(float), + clamp_min, + clamp_max); +} + +size_t get_preferred_alignement() { + return 16; +} +} // namespace neon_dotprod_1x4x32 +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git 
a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h new file mode 100644 index 0000000000..ae971f454f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -0,0 +1,151 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { + +// Helper functions +// TODO: find a better place for these? + +size_t roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +uint16_t get_bf16_from_float(float f) { + uint16_t bf16; +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + memcpy(&bf16, &f, sizeof(uint16_t)); +#else + const void* fp = reinterpret_cast( + reinterpret_cast(&f) + sizeof(float) - sizeof(uint16_t)); + memcpy(&bf16, fp, sizeof(uint16_t)); +#endif // __BYTE_ORDER__ + return bf16; +} + +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { + +using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; + +size_t activation_data_size(const Ukernel ukernel, int m, int k) { + auto lhs_packing = get_lhs_packing(); + return lhs_packing.get_lhs_packed_size( + m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr()); +} + +void prepare_activation_data( + const Ukernel ukernel, + void* activation_data, + int m, + int k, + const float* activations) { + auto lhs_pack = get_lhs_packing(); + + lhs_pack.run_lhs_pack( + m, + k, + ukernel.get_mr(), + ukernel.get_kr(), + ukernel.get_sr(), + /*m_index_start=*/0, + activations, + /*lhs_stride=*/k * sizeof(float), + activation_data); +} + +size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) { + auto rhs_pack = get_rhs_packing(); + return rhs_pack.get_rhs_packed_size( + n, + k, + ukernel.get_nr(), + ukernel.get_kr(), + ukernel.get_sr(), + group_size, + kai_datatype::kai_dt_bf16); +} + +void prepare_weight_data( + const Ukernel ukernel, + void* weight_data, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + // TODO(T204312268) - remove this constraint and pad when possible + assert(n % 2 == 0); + + assert(group_size % 32 == 0); + assert(k % group_size == 0); + + // TODO SIMDify this + size_t n_groups = n * k / group_size; + auto weight_scales_bf16 = std::vector(n_groups, 0); + + // We don't support weight zeros yet + if (weight_zeros != nullptr) { + for (size_t i = 0; i < n_groups; i++) { + assert(weight_zeros[i] == 0); + } + } + + for (size_t i = 0; i < n_groups; i++) { + weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]); + } + + // Prepack weights before packing + // TODO SIMDify this + auto packed_weight_qvals = std::vector(n * k / 2, 0); + uint8_t wzp = 8; + for (size_t i = 0; i < n * k; i += 2) { + const uint8_t low = static_cast(weight_qvals[i] + wzp); + const uint8_t high = static_cast(weight_qvals[i + 1] + wzp); + packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF)); + } + + // Parameters for packing + rhs_packing::qparams_t qparams{ + .lhs_zero_point = 1, + .rhs_zero_point = wzp, + .scale_dt = kai_datatype::kai_dt_bf16}; + + auto rhs_pack = get_rhs_packing(); + + rhs_pack.run_rhs_pack( + /*groups=*/1, + n, + k, + 
ukernel.get_nr(), + ukernel.get_kr(), + ukernel.get_sr(), + group_size, + /*rhs=*/reinterpret_cast(packed_weight_qvals.data()), + /*rhs_stride=*/roundup(k, 2) / 2, + /*bias=*/nullptr, // TODO(T203756650) fix APIs to move bias here + /*scale=*/reinterpret_cast(weight_scales_bf16.data()), + /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size), + /*rhs_packed=*/weight_data, + /*extra_bytes=*/0, + /*qparams=*/&qparams); +} + +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h new file mode 100644 index 0000000000..692df73d55 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h @@ -0,0 +1,118 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { +// All the kernels in this namespace use following packing interface/routines. +// TODO: move these to Kleidi as interfaces? +typedef struct rhs_packing { + typedef struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params qparams_t; + typedef size_t (*get_rhs_offset_t)(size_t, size_t); + typedef size_t (*get_rhs_packed_stride_t)( + size_t, + size_t, + size_t, + size_t, + size_t, + enum kai_datatype); + typedef size_t (*get_rhs_packed_offset_t)( + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + enum kai_datatype); + typedef size_t (*get_rhs_packed_size_t)( + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + enum kai_datatype); + typedef void (*run_rhs_pack_t)( + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + const uint8_t*, + size_t, + const float*, + const void*, + size_t, + void*, + size_t, + const qparams_t*); + + get_rhs_offset_t get_rhs_offset; + get_rhs_packed_stride_t get_rhs_packed_stride; + get_rhs_packed_offset_t get_rhs_packed_offset; + get_rhs_packed_size_t get_rhs_packed_size; + run_rhs_pack_t run_rhs_pack; +} rhs_packing; + +// TODO add transpose variant i.e kxn +rhs_packing get_rhs_packing() { + return rhs_packing{ + .get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, + .get_rhs_packed_stride = + kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, + .get_rhs_packed_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, + .run_rhs_pack = kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0}; +} + +typedef struct lhs_packing { + typedef size_t (*get_lhs_m_step_t)(size_t); + typedef size_t (*get_lhs_offset_t)(size_t, size_t); + typedef size_t ( + *get_lhs_packed_offset_t)(size_t, size_t, size_t, size_t, size_t); + typedef size_t ( + *get_lhs_packed_size_t)(size_t, size_t, size_t, size_t, size_t); + typedef void (*run_lhs_pack_t)( + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + const float*, + size_t, + void*); + + get_lhs_m_step_t get_lhs_m_step; + get_lhs_offset_t get_lhs_offset; + get_lhs_packed_offset_t get_lhs_packed_offset; + get_lhs_packed_size_t get_lhs_packed_size; + run_lhs_pack_t run_lhs_pack; +} lhs_packing; + +lhs_packing get_lhs_packing() { + return lhs_packing{ + .get_lhs_m_step = 
kai_get_m_step_lhs_quant_pack_qai8dxp_f32, + .get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32, + .get_lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32, + .run_lhs_pack = kai_run_lhs_quant_pack_qai8dxp_f32}; +} + +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 8e281ed79e..c9799eadd9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -29,6 +29,18 @@ add_library( ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp ) +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES}) +endif() + +add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +# The TORCHAO_ENABLE_KLEIDI cmake variable should be set by `torchao_kernels_aarch64" +# This is a temporary work around. +if(TORCHAO_ENABLE_KLEIDI) + add_compile_definitions(TORCHAO_ENABLE_KLEIDI) +endif() + enable_testing() add_executable(test_quantization test_quantization.cpp) @@ -61,6 +73,7 @@ target_link_libraries( PRIVATE GTest::gtest_main dep + torchao_kernels_aarch64 ) add_executable(test_valpacking test_valpacking.cpp) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index ce8861ac65..27c584ce20 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -1,16 +1,17 @@ -#!/bin/bash +#!/bin/bash -eu # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +set -eu SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. 
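For reference, prepare_weight_data above re-encodes the weights before handing them to KleidiAI's rhs_pack: each signed int4 value is shifted by a fixed zero point of 8, two values are packed per byte, and the fp32 group scales are truncated to bf16. A standalone illustration of that encoding (the helper name and sample values are made up here, not library code; the bf16 conversion is equivalent in result to get_bf16_from_float above):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate an fp32 value to bf16 by keeping the top 16 bits of its IEEE-754
// encoding (no rounding).
static uint16_t bf16_from_float(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  // Two signed int4 quantized values in [-8, 7]. The packer adds a fixed
  // zero point of 8 to make them unsigned nibbles, then stores two per byte.
  const int8_t first = -3, second = 5;
  const uint8_t wzp = 8;
  const uint8_t low = static_cast<uint8_t>(first + wzp);    // 0x5
  const uint8_t high = static_cast<uint8_t>(second + wzp);  // 0xd
  const uint8_t packed = static_cast<uint8_t>((high << 4) | (low & 0xF));

  const float group_scale = 0.0123f;
  std::printf("packed byte: 0x%02x  bf16 scale: 0x%04x\n",
              static_cast<unsigned>(packed),
              static_cast<unsigned>(bf16_from_float(group_scale)));
  return 0;
}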
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT} +cmake -DCMAKE_BUILD_TYPE=Debug -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT} -cmake --build ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} # Run ${CMAKE_OUT}/test_quantization diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 47902be72c..571b91d476 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -7,11 +7,17 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include +#include + #include #include #include #include -#include + +#ifdef TORCHAO_ENABLE_KLEIDI +#include +#include +#endif float kTol = 0.0001; @@ -350,4 +356,234 @@ TEST( } } +#ifdef TORCHAO_ENABLE_KLEIDI +template +void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + int m, + int k, + int n, + int group_size) { + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + /*weight_nbit=*/4, + /*has_weight_zeros*/false, + has_bias, + has_clamp, + /*weight_scale_bf16_round_trip=*/true); + + using namespace torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; + + std::vector activation_data(activation_data_size(m, k, group_size)); + + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data(weight_data_size(n, k, group_size)); + + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data()); + + std::vector output(m * n); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*bias=*/test_case.bias.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + k_eq_gs_32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + large_k_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + even_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + clamp_k_eq_gs128) { + 
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + m_clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + + +template +void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + int m, + int k, + int n, + int group_size) { + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + /*weight_nbit=*/4, + /*has_weight_zeros=*/false, + has_bias, + has_clamp, + /*round_weight_scales_to_bf16=*/true); + + using namespace torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; + + std::vector activation_data(activation_data_size(m, k, group_size)); + + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data(weight_data_size(n, k, group_size)); + + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data()); + + std::vector output(m * n); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*bias=*/test_case.bias.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + k_eq_gs_32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + large_k_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + even_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + m_clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); +} +#endif // TORCHAO_ENABLE_KLEIDI #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git 
a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index c3dc431c08..e9f36e14ac 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -43,6 +43,28 @@ inline std::vector get_random_lowbit_vector(int size, int nbit) { return res; } +// TODO move these to a common utils +inline uint16_t +get_bf16_from_float(float f) { + uint16_t bf16; +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + memcpy(&bf16, &f, sizeof(uint16_t)); +#else + const void* fp = reinterpret_cast( + reinterpret_cast(&f) + sizeof(float) - sizeof(uint16_t)); + memcpy(&bf16, fp, sizeof(uint16_t)); +#endif // __BYTE_ORDER__ + return bf16; +} + +inline float +get_float_from_bf16(uint16_t bf16) { + float f; + const uint32_t i32 = (bf16 << 16); + memcpy(&f, &i32, sizeof(uint32_t)); + return f; +} + struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int m; int k; @@ -135,7 +157,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int weight_nbit, bool has_weight_zeros, bool has_bias, - bool has_clamp) { + bool has_clamp, + bool round_weight_scales_to_bf16=false) { // activations is m x k (stored in row-major) // weights is k x n (stored in column-major) @@ -198,6 +221,11 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { scale = torchao::quantization::get_scale(vmin, vmax, qmin, qmax); zero = 0; } + if (round_weight_scales_to_bf16) { + // weight scales are bf16 in the kernel + // so we need to round trip them to bf16 and back to float to match it. + scale = get_float_from_bf16(get_bf16_from_float(scale)); + } weight_scales[group_idx] = scale; weight_zeros[group_idx] = zero; @@ -225,6 +253,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { // Compute expected output std::vector expected_output(m * n); + for (int m_idx = 0; m_idx < m; m_idx++) { for (int n_idx = 0; n_idx < n; n_idx++) { float res = 0.0;
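Because the kernel stores group scales as bf16, the test harness round-trips each generated scale through get_bf16_from_float/get_float_from_bf16 so the fp32 reference computation sees exactly the value the kernel will use. A self-contained sketch of that property (hypothetical test name; the helpers are reimplemented here rather than including test_utils.h):

#include <cstdint>
#include <cstring>

#include <gtest/gtest.h>

static uint16_t bf16_from_float(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // drop the low mantissa bits
}

static float float_from_bf16(uint16_t bf16) {
  const uint32_t bits = static_cast<uint32_t>(bf16) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

TEST(test_bf16_round_trip, weight_scale_is_stable_after_one_round_trip) {
  const float scale = 0.0437f;
  const float rounded = float_from_bf16(bf16_from_float(scale));
  // bf16 values are fixed points of the conversion, so rounding the scales
  // once up front (as the test case generator does) keeps the fp32 reference
  // computation consistent with the kernel's bf16 scales.
  EXPECT_EQ(rounded, float_from_bf16(bf16_from_float(rounded)));
  // Truncating the mantissa loses at most about 2^-7 relative precision.
  EXPECT_NEAR(rounded, scale, 1e-3f);
}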