From 036b78278dd1bfcd31bd022dcb1d939a52709ac5 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 1 Oct 2024 20:21:32 -0500
Subject: [PATCH 01/19] [experimental] simple script UX fixes

---
 torchao/experimental/build_torchao_ops.sh      |  9 +++++++--
 .../benchmarks/build_and_run_benchmarks.sh     | 15 ++++++++++++---
 .../cpu/aarch64/tests/build_and_run_tests.sh   |  3 ++-
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/torchao/experimental/build_torchao_ops.sh b/torchao/experimental/build_torchao_ops.sh
index 2cb7201588..1f13f36c77 100644
--- a/torchao/experimental/build_torchao_ops.sh
+++ b/torchao/experimental/build_torchao_ops.sh
@@ -1,16 +1,21 @@
-#!/bin/bash
+#!/bin/bash -eu
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+if [[ $# -ne 1 ]]; then
+  echo "Usage: $0 ";
+  exit 1;
+fi
+TARGET="${1}"
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
 export CMAKE_OUT=/tmp/cmake-out/torchao
 cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
     -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
-    -DTORCHAO_OP_TARGET="$1" \
+    -DTORCHAO_OP_TARGET="${TARGET}" \
     -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
     -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
     -S . \

diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh
index 08f8358365..e7fa9402e2 100644
--- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh
+++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh
@@ -1,15 +1,24 @@
-#!/bin/bash
+#!/bin/bash -eu
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK}
+set -eu
+if [[ $# -ne 1 ]]; then
+  echo "Usage: $0 ";
+  exit 1;
+fi
+
+BENCHMARK_TYPE="${1}"
 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
+
 export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
 export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks
+
+# Build
 cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
   -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \
   -B ${CMAKE_OUT}
@@ -17,7 +26,7 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
 cmake --build ${CMAKE_OUT}
 
 # Run
-case "$1" in
+case "${BENCHMARK_TYPE}" in
   quantization) ${CMAKE_OUT}/benchmark_quantization; ;;
   bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;;
   linear) ${CMAKE_OUT}/benchmark_linear; ;;

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
index ce8861ac65..98ce559914 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
@@ -1,10 +1,11 @@
-#!/bin/bash
+#!/bin/bash -eu
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+set -eu
 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
 export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests

From c4b9f1e72ee9d8dfab6094904a35021de97e833f Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 1 Oct 2024 20:21:04 -0500
Subject: [PATCH 02/19] [experimental][kleidi] Add build support

---
 .../kernels/cpu/aarch64/CMakeLists.txt | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
index 4f36945f8a..ebc9716832 100644
--- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -4,6 +4,24 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+include(FetchContent)
+
+# KleidiAI is an open-source library that provides optimized
+# performance-critical routines, also known as micro-kernels, for artificial
+# intelligence (AI) workloads tailored for Arm® CPUs.
+FetchContent_Declare(kleidiai
+  GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
+  GIT_TAG main) # TODO: set a pin
+
+FetchContent_MakeAvailable(kleidiai)
+
+# Disabled by default. Force enable if we are on a suitable system.
+# TODO: Introduce ISA specific flags for i8mm.
+CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library"
+  OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON)
+
+add_compile_definitions("TORCHAO_ENABLE_KLEIDI=$")
+
 if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
   add_library(
     torchao_kernels_aarch64
@@ -12,6 +30,10 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
     ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
     ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
   )
+  if (BUILD_KLEIDI)
+    message(STATUS "Building with Kleidi")
+    target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
+  endif()
 endif()
 
 install(
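Annotation: the TORCHAO_ENABLE_KLEIDI compile definition added in patch 02 is what later patches key on to compile Kleidi-backed code paths. A minimal sketch of consuming that definition from C++; the function and string labels here are hypothetical illustrations, not part of the series:

// Sketch only: shows how code built under this CMake setup could branch
// on the TORCHAO_ENABLE_KLEIDI definition. select_backend() is made up.
#include <string>

std::string select_backend() {
#if defined(TORCHAO_ENABLE_KLEIDI) && TORCHAO_ENABLE_KLEIDI
  return "kleidiai";            // KleidiAI micro-kernels were linked in
#else
  return "aarch64-reference";   // fall back to the in-tree kernels
#endif
}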
From 4a85c4db13589e23352724bc6b9fe645f96a7007 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 1 Oct 2024 21:44:51 -0500
Subject: [PATCH 03/19] [experimental][kleidi] Add uConfig support for qb4w
 1x4x32 neon dotprod

---
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 107 ++++++++++++++++++
 .../kernels/cpu/aarch64/kleidi/pack.h       |  65 +++++++++++
 2 files changed, 172 insertions(+)
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
new file mode 100644
index 0000000000..4d785fb2b3
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -0,0 +1,107 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+  namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+
+    using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
+
+    namespace neon_dotprod_1x4x32 {
+      ukernel get_ukernel() {
+        return ukernel {
+          .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+          .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod
+        };
+      }
+
+      int activation_data_size(int m, int k, int group_size) {
+        auto ukernel = get_ukernel();
+        auto lhs_packing = get_lhs_packing();
+        return lhs_packing.get_lhs_packed_size(m, k, group_size, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
+      }
+
+      void prepare_activation_data(
+          void* activation_data,
+          // Inputs
+          int m,
+          int k,
+          // Ignored if has_weight_zeros = false
+          int group_size,
+          const float* activations) {
+        auto ukernel = get_ukernel();
+        auto lhs_pack = get_lhs_packing();
+        lhs_pack.run_lhs_pack(m, k, group_size, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr(), /*m_index_start=*/0,
+            activations, /*lhs_stride=*/ k*sizeof(float), activation_data);
+      }
+
+      int weight_data_size(int n, int k, int group_size) {
+        auto ukernel = get_ukernel();
+        auto rhs_pack = get_rhs_packing();
+        return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(), group_size);
+      }
+
+      void prepare_weight_data(
+          void* weight_data,
+          // Inputs
+          int n,
+          int k,
+          int group_size,
+          const int8_t* weight_qvals,
+          const float* weight_scales,
+          const int8_t* weight_zeros) {
+        if (weight_zeros) {
+          // TODO check all zeros
+          assert (weight_zeros[0] == 8);
+        }
+        auto ukernel = get_ukernel();
+        auto rhs_pack = get_rhs_packing();
+        rhs_packing::qparams_t qparams{1, 8};
+        // @nocommit - Unsigned hack, add a naive packing routine
+        rhs_pack.run_rhs_pack(/*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(),
+            group_size, reinterpret_cast<const uint8_t*>(weight_qvals), /*bias=*/nullptr, weight_data, /*extra_bytes=*/0, &qparams);
+      }
+
+      void kernel(
+          // Outputs
+          float32_t* output,
+          // Inputs
+          int output_m_stride,
+          int m,
+          int n,
+          int k,
+          int group_size,
+          const void* weight_data,
+          const void* activation_data,
+          // Not applied if nullptr
+          const float* bias,
+          // Ignored if has_clamp = false
+          float clamp_min,
+          float clamp_max) {
+        auto ukernel = get_ukernel();
+        ukernel.run_matmul(m, n, k, group_size, activation_data, weight_data, output, output_m_stride, /*dst_stride_col=*/1, clamp_min, clamp_max);
+      }
+
+      size_t get_alignement() {
+        return 16;
+      }
+    } // namespace neon_dotprod_1x4x32
+  } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h
new file mode 100644
index 0000000000..808c321a0f
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h
@@ -0,0 +1,65 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+  namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+    // All the kernels in this namespace use following packing interface/routines.
+    // TODO: move these to Kleidi as interfaces?
+    typedef struct rhs_packing {
+      typedef struct kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0_params qparams_t;
+      typedef size_t (*get_rhs_offset_t)(size_t, size_t);
+      typedef size_t (*get_rhs_packed_offset_t)(size_t, size_t, size_t, size_t, size_t);
+      typedef size_t (*get_rhs_packed_size_t)(size_t, size_t, size_t, size_t, size_t);
+      typedef void (*run_rhs_pack_t)(size_t, size_t, size_t, size_t, size_t, size_t, size_t, const uint8_t*, const float*, void*, size_t, const qparams_t*);
+
+      get_rhs_offset_t get_rhs_offset;
+      get_rhs_packed_offset_t get_rhs_packed_offset;
+      get_rhs_packed_size_t get_rhs_packed_size;
+      run_rhs_pack_t run_rhs_pack;
+    } rhs_packing;
+
+    typedef struct lhs_packing {
+      typedef size_t (*get_lhs_m_step_t)(size_t);
+      typedef size_t (*get_lhs_offset_t)(size_t, size_t);
+      typedef size_t (*get_lhs_packed_offset_t)(size_t, size_t, size_t, size_t, size_t, size_t);
+      typedef size_t (*get_lhs_packed_size_t)(size_t, size_t, size_t, size_t, size_t, size_t);
+      typedef void (*run_lhs_pack_t)(size_t, size_t, size_t, size_t, size_t, size_t, size_t, const float*, size_t, void*);
+
+      get_lhs_m_step_t get_lhs_m_step;
+      get_lhs_offset_t get_lhs_offset;
+      get_lhs_packed_offset_t get_lhs_packed_offset;
+      get_lhs_packed_size_t get_lhs_packed_size;
+      run_lhs_pack_t run_lhs_pack;
+    } lhs_packing;
+
+    // TODO add transpose variant i.e kxn
+    rhs_packing get_rhs_packing() {
+      return rhs_packing {
+        .get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        .get_rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        .get_rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        .run_rhs_pack = kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0
+      };
+    }
+
+    lhs_packing get_lhs_packing() {
+      return lhs_packing {
+        .get_lhs_m_step = kai_get_m_step_lhs_quant_pack_qsi8d32p_f32,
+        .get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+        .get_lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+        .get_lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+        .run_lhs_pack = kai_run_lhs_quant_pack_qsi8d32p_f32
+      };
+    }
+
+  } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi
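Annotation: patch 03's header establishes a three-step calling convention — query packed-buffer sizes, pack both operands, then run the matmul. A minimal caller sketch under assumed shapes; the buffers and fill values are illustrative, while the entry points are the ones declared above:

// Sketch of the pack-then-run flow exposed by the header in patch 03.
// Shapes and input contents are illustrative assumptions only.
#include <cstdint>
#include <limits>
#include <vector>

using namespace torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;

void run_linear_example() {
  const int m = 1, k = 64, n = 16, group_size = 32;
  std::vector<float> activations(m * k, 1.0f);       // fp32 LHS
  std::vector<int8_t> weight_qvals(n * k, 1);        // 4-bit values held in int8
  std::vector<float> weight_scales(n * k / group_size, 0.5f);

  // 1) query sizes, 2) pack both operands, 3) run the ukernel.
  std::vector<char> packed_lhs(activation_data_size(m, k, group_size));
  prepare_activation_data(packed_lhs.data(), m, k, group_size, activations.data());

  std::vector<char> packed_rhs(weight_data_size(n, k, group_size));
  prepare_weight_data(packed_rhs.data(), n, k, group_size,
                      weight_qvals.data(), weight_scales.data(),
                      /*weight_zeros=*/nullptr);

  std::vector<float> output(m * n);
  kernel(output.data(), /*output_m_stride=*/n, m, n, k, group_size,
         packed_rhs.data(), packed_lhs.data(), /*bias=*/nullptr,
         /*clamp_min=*/std::numeric_limits<float>::lowest(),
         /*clamp_max=*/std::numeric_limits<float>::max());
}

Patch 04 below wires exactly this sequence into a GTest case.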
From 49afa4a6d7777e20c4f90f3f4e7c0fa417d2add4 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 1 Oct 2024 21:41:07 -0500
Subject: [PATCH 04/19] [experimental][kleidi] Add a basic test - compiles

---
 .../kernels/cpu/aarch64/tests/CMakeLists.txt  |  6 ++
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 72 +++++++++++++++++++
 2 files changed, 78 insertions(+)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
index 8e281ed79e..29f38b57fa 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
@@ -29,6 +29,11 @@ add_library(
   ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
 )
 
+if(NOT TORCHAO_INCLUDE_DIRS)
+  set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES})
+endif()
+add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
+
 enable_testing()
 
 add_executable(test_quantization test_quantization.cpp)
@@ -61,6 +66,7 @@ target_link_libraries(
   PRIVATE
     GTest::gtest_main
     dep
+    torchao_kernels_aarch64
 )
 
 add_executable(test_valpacking test_valpacking.cpp)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index 47902be72c..d19b711133 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 #include
 
 float kTol = 0.0001;
@@ -350,4 +351,75 @@ TEST(
   }
 }
 
+template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    int m,
+    int k,
+    int n,
+    int group_size) {
+  auto test_case = torchao::
+      channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
+          m,
+          k,
+          n,
+          group_size,
+          weight_nbit,
+          has_weight_zeros,
+          has_bias,
+          has_clamp);
+
+  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
+
+  std::vector<char> activation_data(
+      activation_data_size(m, k, group_size));
+
+  prepare_activation_data(
+      (void*)activation_data.data(),
+      m,
+      k,
+      group_size,
+      test_case.activations.data());
+
+  std::vector<char> weight_data(
+      weight_data_size(n, k, group_size));
+
+  prepare_weight_data(
+      (void*)weight_data.data(),
+      n,
+      k,
+      group_size,
+      test_case.weight_qvals.data(),
+      test_case.weight_scales.data(),
+      /*weight_zeros=*/test_case.weight_zeros.data());
+
+  std::vector<float> output(m * n);
+  kernel(
+      output.data(),
+      /*output_m_stride=*/n,
+      m,
+      n,
+      k,
+      group_size,
+      weight_data.data(),
+      activation_data.data(),
+      /*bias=*/test_case.bias.data(),
+      /*clamp_min=*/test_case.clamp_min,
+      /*clamp_max=*/test_case.clamp_max);
+
+  for (int i = 0; i < m * n; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    only_supported) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      4 /*weight_nbit*/,
+      false /*has_weight_zeros*/,
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/16, /*k=*/64, /*n=*/16, /*group_size=*/32);
+}
+
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
From 569c069d5735cec5b5e86ef2a366ffb82718f426 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 7 Oct 2024 20:37:17 -0500
Subject: [PATCH 05/19] [experimental][kleidi] Pin kleidiai repo

---
 torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
index ebc9716832..da99a55ed1 100644
--- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -11,7 +11,7 @@ include(FetchContent)
 # intelligence (AI) workloads tailored for Arm® CPUs.
 FetchContent_Declare(kleidiai
   GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
-  GIT_TAG main) # TODO: set a pin
+  GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
 
 FetchContent_MakeAvailable(kleidiai)

From fd1423f55992c636afcae978dca12619e90e7e85 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 7 Oct 2024 20:46:14 -0500
Subject: [PATCH 06/19] [experimental][kleidi] Clean up pack.h

---
 .../kernels/cpu/aarch64/kleidi/pack.h | 147 ++++++++++++------
 1 file changed, 100 insertions(+), 47 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h
index 808c321a0f..692df73d55 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h
@@ -7,59 +7,112 @@
 #pragma once
 
 #include
-#include
-#include
+
+#include
+#include
+#include
 
 namespace torchao::kernels::cpu::aarch64::kleidi {
-  namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
-    // All the kernels in this namespace use following packing interface/routines.
-    // TODO: move these to Kleidi as interfaces?
-    typedef struct rhs_packing {
-      typedef struct kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0_params qparams_t;
-      typedef size_t (*get_rhs_offset_t)(size_t, size_t);
-      typedef size_t (*get_rhs_packed_offset_t)(size_t, size_t, size_t, size_t, size_t);
-      typedef size_t (*get_rhs_packed_size_t)(size_t, size_t, size_t, size_t, size_t);
-      typedef void (*run_rhs_pack_t)(size_t, size_t, size_t, size_t, size_t, size_t, size_t, const uint8_t*, const float*, void*, size_t, const qparams_t*);
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+// All the kernels in this namespace use following packing interface/routines.
+// TODO: move these to Kleidi as interfaces?
+typedef struct rhs_packing {
+  typedef struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params qparams_t;
+  typedef size_t (*get_rhs_offset_t)(size_t, size_t);
+  typedef size_t (*get_rhs_packed_stride_t)(
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      enum kai_datatype);
+  typedef size_t (*get_rhs_packed_offset_t)(
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      enum kai_datatype);
+  typedef size_t (*get_rhs_packed_size_t)(
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      enum kai_datatype);
+  typedef void (*run_rhs_pack_t)(
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      const uint8_t*,
+      size_t,
+      const float*,
+      const void*,
+      size_t,
+      void*,
+      size_t,
+      const qparams_t*);
 
-      get_rhs_offset_t get_rhs_offset;
-      get_rhs_packed_offset_t get_rhs_packed_offset;
-      get_rhs_packed_size_t get_rhs_packed_size;
-      run_rhs_pack_t run_rhs_pack;
-    } rhs_packing;
+  get_rhs_offset_t get_rhs_offset;
+  get_rhs_packed_stride_t get_rhs_packed_stride;
+  get_rhs_packed_offset_t get_rhs_packed_offset;
+  get_rhs_packed_size_t get_rhs_packed_size;
+  run_rhs_pack_t run_rhs_pack;
+} rhs_packing;
 
-    typedef struct lhs_packing {
-      typedef size_t (*get_lhs_m_step_t)(size_t);
-      typedef size_t (*get_lhs_offset_t)(size_t, size_t);
-      typedef size_t (*get_lhs_packed_offset_t)(size_t, size_t, size_t, size_t, size_t, size_t);
-      typedef size_t (*get_lhs_packed_size_t)(size_t, size_t, size_t, size_t, size_t, size_t);
-      typedef void (*run_lhs_pack_t)(size_t, size_t, size_t, size_t, size_t, size_t, size_t, const float*, size_t, void*);
+// TODO add transpose variant i.e kxn
+rhs_packing get_rhs_packing() {
+  return rhs_packing{
+      .get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0,
+      .get_rhs_packed_stride =
+          kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0,
+      .get_rhs_packed_offset =
+          kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0,
+      .get_rhs_packed_size =
+          kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0,
+      .run_rhs_pack = kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0};
+}
 
-      get_lhs_m_step_t get_lhs_m_step;
-      get_lhs_offset_t get_lhs_offset;
-      get_lhs_packed_offset_t get_lhs_packed_offset;
-      get_lhs_packed_size_t get_lhs_packed_size;
-      run_lhs_pack_t run_lhs_pack;
-    } lhs_packing;
+typedef struct lhs_packing {
+  typedef size_t (*get_lhs_m_step_t)(size_t);
+  typedef size_t (*get_lhs_offset_t)(size_t, size_t);
+  typedef size_t (
+      *get_lhs_packed_offset_t)(size_t, size_t, size_t, size_t, size_t);
+  typedef size_t (
+      *get_lhs_packed_size_t)(size_t, size_t, size_t, size_t, size_t);
+  typedef void (*run_lhs_pack_t)(
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      size_t,
+      const float*,
+      size_t,
+      void*);
 
-    // TODO add transpose variant i.e kxn
-    rhs_packing get_rhs_packing() {
-      return rhs_packing {
-        .get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-        .get_rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-        .get_rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-        .run_rhs_pack = kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0
-      };
-    }
+  get_lhs_m_step_t get_lhs_m_step;
+  get_lhs_offset_t get_lhs_offset;
+  get_lhs_packed_offset_t get_lhs_packed_offset;
+  get_lhs_packed_size_t get_lhs_packed_size;
+  run_lhs_pack_t run_lhs_pack;
+} lhs_packing;
 
-    lhs_packing get_lhs_packing() {
-      return lhs_packing {
-        .get_lhs_m_step = kai_get_m_step_lhs_quant_pack_qsi8d32p_f32,
-        .get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-        .get_lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-        .get_lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-        .run_lhs_pack = kai_run_lhs_quant_pack_qsi8d32p_f32
-      };
-    }
+lhs_packing get_lhs_packing() {
+  return lhs_packing{
+      .get_lhs_m_step = kai_get_m_step_lhs_quant_pack_qai8dxp_f32,
+      .get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+      .get_lhs_packed_offset =
+          kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32,
+      .get_lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32,
+      .run_lhs_pack = kai_run_lhs_quant_pack_qai8dxp_f32};
+}
 
-  } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
 } // namespace torchao::kernels::cpu::aarch64::kleidi
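Annotation: after patch 06, pack.h models each Kleidi packing routine as a small struct of function pointers so callers stay agnostic of the concrete kai_* symbols. A sketch of dispatching through the lhs table; the mr/kr/sr values are made-up examples (they normally come from the ukernel's getters):

// Sketch: driving the lhs_packing function-pointer table from pack.h.
// mr/kr/sr here are illustrative; real code queries them from the ukernel.
#include <vector>

using torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::get_lhs_packing;

void pack_lhs_example(const float* activations, int m, int k) {
  auto lhs = get_lhs_packing();           // binds the kai_*_qai8dxp_f32 symbols
  const size_t mr = 1, kr = 8, sr = 2;    // assumed ukernel geometry
  std::vector<char> packed(lhs.get_lhs_packed_size(m, k, mr, kr, sr));
  lhs.run_lhs_pack(m, k, mr, kr, sr, /*m_idx_start=*/0,
                   activations, /*lhs_stride=*/k * sizeof(float),
                   packed.data());
}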
From c323fb13393c32fa3ec95c3a0a216cffc91613cc Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 7 Oct 2024 20:46:46 -0500
Subject: [PATCH 07/19] [experimental][kleidi] Refactor interface header

---
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 214 ++++++++++++++++++
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h   |  97 +-------
 2 files changed, 218 insertions(+), 93 deletions(-)
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
new file mode 100644
index 0000000000..fb686d55c2
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -0,0 +1,214 @@
+// namespace example
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+
+using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
+
+namespace neon_dotprod_1x4x32 {
+ukernel get_ukernel() {
+  return ukernel{
+      .get_m_step =
+          kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_n_step =
+          kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_mr =
+          kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_nr =
+          kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_kr =
+          kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_sr =
+          kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_lhs_packed_offset =
+          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_rhs_packed_offset =
+          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_dst_offset =
+          kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .get_dst_size =
+          kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+      .run_matmul =
+          kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
+}
+
+size_t roundup(size_t a, size_t b) {
+  return ((a + b - 1) / b) * b;
+}
+
+int activation_data_size(int m, int k, int group_size) {
+  auto ukernel = get_ukernel();
+  auto lhs_packing = get_lhs_packing();
+  return lhs_packing.get_lhs_packed_size(
+      m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
+}
+
+void prepare_activation_data(
+    void* activation_data,
+    // Inputs
+    int m,
+    int k,
+    // Ignored if has_weight_zeros = false
+    int group_size,
+    const float* activations) {
+  auto ukernel = get_ukernel();
+  auto lhs_pack = get_lhs_packing();
+
+  lhs_pack.run_lhs_pack(
+      m,
+      k,
+      ukernel.get_mr(),
+      ukernel.get_kr(),
+      ukernel.get_sr(),
+      /*m_index_start=*/0,
+      activations,
+      /*lhs_stride=*/k * sizeof(float),
+      activation_data);
+}
+
+int weight_data_size(int n, int k, int group_size) {
+  auto ukernel = get_ukernel();
+  auto rhs_pack = get_rhs_packing();
+  return rhs_pack.get_rhs_packed_size(
+      n,
+      k,
+      ukernel.get_nr(),
+      ukernel.get_kr(),
+      ukernel.get_sr(),
+      group_size,
+      kai_datatype::kai_dt_bf16);
+}
+
+inline uint16_t get_bf16_from_float(float f) {
+  uint16_t bf16;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  memcpy(&bf16, &f, sizeof(uint16_t));
+#else
+  const void* fp = reinterpret_cast<const void*>(
+      reinterpret_cast<const char*>(&f) + sizeof(float) - sizeof(uint16_t));
+  memcpy(&bf16, fp, sizeof(uint16_t));
+#endif // __BYTE_ORDER__
+  return bf16;
+}
+
+// TODO: move most of these functions in the parent namespace and take in
+// ukernel as a parameter
+void prepare_weight_data(
+    void* weight_data,
+    // Inputs
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    const int8_t* weight_zeros) {
+  // TODO - remove this constraint and pad when possible
+  assert(n % 2 == 0);
+
+  assert(group_size % 32 == 0);
+  assert(k % group_size == 0);
+
+  // Convert scales to bf16
+  // TODO SIMDify this
+  size_t n_groups = n * k / group_size;
+  auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
+  for (size_t i = 0; i < n_groups; i++) {
+    assert(weight_zeros[i] == 0);
+    weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
+  }
+
+  // Prepack weights before packing
+  // TODO SIMDify this
+  auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 0);
+  uint8_t wzp = 8;
+  for (size_t i = 0; i < n * k; i += 2) {
+    const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
+    const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
+    packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
+  }
+
+  // Parameters for packing
+  rhs_packing::qparams_t qparams{
+      .lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
+
+  auto ukernel = get_ukernel();
+  auto rhs_pack = get_rhs_packing();
+
+  rhs_pack.run_rhs_pack(
+      /*groups=*/1,
+      n,
+      k,
+      ukernel.get_nr(),
+      ukernel.get_kr(),
+      ukernel.get_sr(),
+      group_size,
+      /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
+      /*rhs_stride=*/roundup(k, 2) / 2,
+      /*bias=*/nullptr, // TODO fix APIs to move bias here
+      /*scale=*/reinterpret_cast<const float*>(weight_scales_bf16.data()),
+      /*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
+      /*rhs_packed=*/weight_data,
+      /*extra_bytes=*/0,
+      /*qparams=*/&qparams);
+}
+
+void kernel(
+    // Outputs
+    float32_t* output,
+    // Inputs
+    int output_m_stride,
+    int m,
+    int n,
+    int k,
+    int group_size,
+    const void* weight_data,
+    const void* activation_data,
+    // Not applied if nullptr
+    const float* bias,
+    // zeros if has_clamp = false
+    float clamp_min,
+    float clamp_max) {
+  assert(output_m_stride == n);
+  if (clamp_min == clamp_max && clamp_min == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
+      m,
+      n,
+      k,
+      group_size,
+      activation_data,
+      weight_data,
+      output,
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
+      clamp_min,
+      clamp_max);
+}
+
+size_t get_alignement() {
+  return 16;
+}
+} // namespace neon_dotprod_1x4x32
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index 4d785fb2b3..e4b28c93bf 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -1,3 +1,4 @@
+// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //
@@ -6,102 +7,12 @@
 
 #pragma once
 
-#include
 #include
-#include
 
 namespace torchao::kernels::cpu::aarch64::kleidi {
-  namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
-    using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
+using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
-    namespace neon_dotprod_1x4x32 {
-      ukernel get_ukernel() {
-        return ukernel {
-          .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-          .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod
-        };
-      }
-
-      int activation_data_size(int m, int k, int group_size) {
-        auto ukernel = get_ukernel();
-        auto lhs_packing = get_lhs_packing();
-        return lhs_packing.get_lhs_packed_size(m, k, group_size, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
-      }
-
-      void prepare_activation_data(
-          void* activation_data,
-          // Inputs
-          int m,
-          int k,
-          // Ignored if has_weight_zeros = false
-          int group_size,
-          const float* activations) {
-        auto ukernel = get_ukernel();
-        auto lhs_pack = get_lhs_packing();
-        lhs_pack.run_lhs_pack(m, k, group_size, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr(), /*m_index_start=*/0,
-            activations, /*lhs_stride=*/ k*sizeof(float), activation_data);
-      }
-
-      int weight_data_size(int n, int k, int group_size) {
-        auto ukernel = get_ukernel();
-        auto rhs_pack = get_rhs_packing();
-        return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(), group_size);
-      }
-
-      void prepare_weight_data(
-          void* weight_data,
-          // Inputs
-          int n,
-          int k,
-          int group_size,
-          const int8_t* weight_qvals,
-          const float* weight_scales,
-          const int8_t* weight_zeros) {
-        if (weight_zeros) {
-          // TODO check all zeros
-          assert (weight_zeros[0] == 8);
-        }
-        auto ukernel = get_ukernel();
-        auto rhs_pack = get_rhs_packing();
-        rhs_packing::qparams_t qparams{1, 8};
-        // @nocommit - Unsigned hack, add a naive packing routine
-        rhs_pack.run_rhs_pack(/*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(),
-            group_size, reinterpret_cast<const uint8_t*>(weight_qvals), /*bias=*/nullptr, weight_data, /*extra_bytes=*/0, &qparams);
-      }
-
-      void kernel(
-          // Outputs
-          float32_t* output,
-          // Inputs
-          int output_m_stride,
-          int m,
-          int n,
-          int k,
-          int group_size,
-          const void* weight_data,
-          const void* activation_data,
-          // Not applied if nullptr
-          const float* bias,
-          // Ignored if has_clamp = false
-          float clamp_min,
-          float clamp_max) {
-        auto ukernel = get_ukernel();
-        ukernel.run_matmul(m, n, k, group_size, activation_data, weight_data, output, output_m_stride, /*dst_stride_col=*/1, clamp_min, clamp_max);
-      }
-
-      size_t get_alignement() {
-        return 16;
-      }
-    } // namespace neon_dotprod_1x4x32
-  } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
 } // namespace torchao::kernels::cpu::aarch64::kleidi

From 8aa27c4949c22156b3e29d1c351f33f99357351a Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 7 Oct 2024 20:47:18 -0500
Subject: [PATCH 08/19] [experimental][kleidi] Improve unit-tests

---
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 41 +++++++++--
 .../kernels/cpu/aarch64/tests/test_utils.h    | 68 ++++++++++++++++++-
 2 files changed, 104 insertions(+), 5 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index d19b711133..90fee4ffd3 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -11,7 +11,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 
 float kTol = 0.0001;
@@ -366,7 +366,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
           weight_nbit,
           has_weight_zeros,
           has_bias,
-          has_clamp);
+          has_clamp,
+          /*weight_scale_bf16_round_trip=*/true);
 
   using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
 
@@ -413,13 +414,45 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
 
 TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-    only_supported) {
+    k_eq_gs_32) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      4 /*weight_nbit*/,
+      false /*has_weight_zeros*/,
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    large_k_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
       4 /*weight_nbit*/,
       false /*has_weight_zeros*/,
      false /*has_bias*/,
       false /*has_clamp*/>(
-      /*m=*/16, /*k=*/64, /*n=*/16, /*group_size=*/32);
+      /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
 }
 
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    even_n_gs32) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      4 /*weight_nbit*/,
+      false /*has_weight_zeros*/,
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
+}
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      4 /*weight_nbit*/,
+      false /*has_weight_zeros*/,
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
 #endif // defined(__aarch64__) || defined(__ARM_NEON)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
index c3dc431c08..4316f8f2d6 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -43,6 +43,26 @@ inline std::vector get_random_lowbit_vector(int size, int nbit) {
   return res;
 }
 
+// TODO move these to a common utils
+uint16_t get_bf16_from_float(float f) {
+  uint16_t bf16;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  memcpy(&bf16, &f, sizeof(uint16_t));
+#else
+  const void* fp = reinterpret_cast<const void*>(
+      reinterpret_cast<const char*>(&f) + sizeof(float) - sizeof(uint16_t));
+  memcpy(&bf16, fp, sizeof(uint16_t));
+#endif // __BYTE_ORDER__
+  return bf16;
+}
+
+float get_float_from_bf16(uint16_t bf16) {
+  float f;
+  const uint32_t i32 = (bf16 << 16);
+  memcpy(&f, &i32, sizeof(uint32_t));
+  return f;
+}
+
 struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
   int m;
   int k;
@@ -135,7 +155,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
       int weight_nbit,
       bool has_weight_zeros,
       bool has_bias,
-      bool has_clamp) {
+      bool has_clamp,
+      bool weight_scale_bf16_round_trip=false) {
     // activations is m x k (stored in row-major)
     // weights is k x n (stored in column-major)
@@ -198,6 +219,11 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
         scale = torchao::quantization::get_scale(vmin, vmax, qmin, qmax);
         zero = 0;
       }
+      if (weight_scale_bf16_round_trip) {
+        // weight scales are bf16 in the kernel
+        // so we need to round trip them to bf16 and back to float to match it.
+        scale = get_float_from_bf16(get_bf16_from_float(scale));
+      }
 
       weight_scales[group_idx] = scale;
       weight_zeros[group_idx] = zero;
@@ -209,6 +235,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
           zero,
           qmin,
           qmax);
+      // std::fill(weight_qvals.begin(), weight_qvals.end(), -7);
     }
 
     std::vector bias(m, 0.0);
@@ -225,6 +252,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
 
     // Compute expected output
     std::vector expected_output(m * n);
+
     for (int m_idx = 0; m_idx < m; m_idx++) {
       for (int n_idx = 0; n_idx < n; n_idx++) {
         float res = 0.0;
@@ -249,6 +277,44 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
       }
     }
 
+#if 0 // Alternate reference implementation for debugging.
+    auto num_groups = k / weight_group_size;
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        int32_t result_idx = m_idx * n + n_idx;
+        float weights_fsum = 0.0;
+        for (int g_idx = 0; g_idx < num_groups; g_idx++) {
+          int32_t weights_qsum = 0;
+          int32_t acc_i32 = 0;
+          for (int k_idx = 0; k_idx < weight_group_size; k_idx++) {
+            const int32_t activation_idx = m_idx * k + g_idx * weight_group_size + k_idx;
+            const int32_t weight_idx = n_idx * k + g_idx * weight_group_size + k_idx;
+
+            const int32_t weight_qval = weight_qvals[weight_idx];
+            const int32_t activation_qval = activation_qvals[activation_idx];
+
+            weights_qsum += weight_qval;
+            acc_i32 += weight_qval * activation_qval;
+          }
+          // For each group, we have a weight scale
+          const int32_t weight_scale_idx = n_idx * num_groups + g_idx;
+          const float weight_scale = weight_scales[weight_scale_idx]; // already rounded trip to bf16
+          expected_output[result_idx] += (float) acc_i32 * weight_scales[weight_scale_idx];
+          weights_fsum += weights_qsum * weight_scale;
+        }
+        // For each output channel, we have an activation scale
+        const int32_t activation_zero_point = activation_zeros[m_idx];
+        const float activation_scale = activation_scales[m_idx];
+        expected_output[result_idx] -= activation_zero_point * weights_fsum;
+        expected_output[result_idx] *= activation_scale;
+        expected_output[result_idx] += bias[m_idx];
+        if (has_clamp) {
+          expected_output[result_idx] = std::min(std::max(expected_output[result_idx], clamp_min), clamp_max);
+        }
+      }
+    }
+#endif
+
     // Return test case
     return channelwise_8bit_activation_groupwise_lowbit_weight_test_case(
         m,
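Annotation: get_bf16_from_float above simply keeps the high 16 bits of the fp32 representation (sign, exponent, and the top 7 mantissa bits), which is why the tests round-trip the reference scales before comparing. A small self-contained illustration of that truncation; the sample value is arbitrary:

// Illustration of the bf16 truncation performed by get_bf16_from_float /
// get_float_from_bf16; 0.3f is just an example scale.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float scale = 0.3f;
  uint32_t bits;
  std::memcpy(&bits, &scale, sizeof(bits));
  uint16_t bf16 = static_cast<uint16_t>(bits >> 16);   // drop low 16 mantissa bits
  uint32_t back_bits = static_cast<uint32_t>(bf16) << 16;
  float back;
  std::memcpy(&back, &back_bits, sizeof(back));
  // back != scale in general; the test harness therefore round-trips the
  // reference scales through bf16 so expected outputs match the kernel.
  std::printf("%.9f -> %.9f\n", scale, back);
  return 0;
}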
From 44ca4defb5decbad446fc8ebf4537f552bc742e7 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 7 Oct 2024 21:46:26 -0500
Subject: [PATCH 09/19] [experimental][kleidi] move common functions to
 interface

---
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 146 ++++--------------
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h   | 127 ++++++++++++++-
 2 files changed, 155 insertions(+), 118 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
index fb686d55c2..381ba9489d 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -6,26 +6,17 @@
 // LICENSE file in the root directory of this source tree.
 
 #pragma once
-#include
-#include
-#include
-#include
-
-#include
-#include
-
-#include
 #include
-#include
+
+#include
 
 namespace torchao::kernels::cpu::aarch64::kleidi {
 namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
-using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
-
 namespace neon_dotprod_1x4x32 {
-ukernel get_ukernel() {
-  return ukernel{
+const Ukernel get_ukernel() {
+  return Ukernel{
       .get_m_step =
           kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
       .get_n_step =
       kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
 }
 
-size_t roundup(size_t a, size_t b) {
-  return ((a + b - 1) / b) * b;
-}
-
 int activation_data_size(int m, int k, int group_size) {
-  auto ukernel = get_ukernel();
-  auto lhs_packing = get_lhs_packing();
-  return lhs_packing.get_lhs_packed_size(
-      m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
+  (void) group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
 }
 
 void prepare_activation_data(
     void* activation_data,
-    // Inputs
     int m,
     int k,
-    // Ignored if has_weight_zeros = false
     int group_size,
     const float* activations) {
-  auto ukernel = get_ukernel();
-  auto lhs_pack = get_lhs_packing();
-
-  lhs_pack.run_lhs_pack(
+  (void) group_size; // unused
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
+      get_ukernel(),
+      activation_data,
       m,
       k,
-      ukernel.get_mr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
-      /*m_index_start=*/0,
-      activations,
-      /*lhs_stride=*/k * sizeof(float),
-      activation_data);
+      activations);
 }
 
 int weight_data_size(int n, int k, int group_size) {
-  auto ukernel = get_ukernel();
-  auto rhs_pack = get_rhs_packing();
-  return rhs_pack.get_rhs_packed_size(
-      n,
-      k,
-      ukernel.get_nr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
-      group_size,
-      kai_datatype::kai_dt_bf16);
-}
-
-inline uint16_t get_bf16_from_float(float f) {
-  uint16_t bf16;
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-  memcpy(&bf16, &f, sizeof(uint16_t));
-#else
-  const void* fp = reinterpret_cast<const void*>(
-      reinterpret_cast<const char*>(&f) + sizeof(float) - sizeof(uint16_t));
-  memcpy(&bf16, fp, sizeof(uint16_t));
-#endif // __BYTE_ORDER__
-  return bf16;
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
 }
 
-// TODO: move most of these functions in the parent namespace and take in
-// ukernel as a parameter
 void prepare_weight_data(
     void* weight_data,
-    // Inputs
     int n,
     int k,
     int group_size,
     const int8_t* weight_qvals,
     const float* weight_scales,
     const int8_t* weight_zeros) {
-  // TODO - remove this constraint and pad when possible
-  assert(n % 2 == 0);
-
-  assert(group_size % 32 == 0);
-  assert(k % group_size == 0);
-
-  // Convert scales to bf16
-  // TODO SIMDify this
-  size_t n_groups = n * k / group_size;
-  auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
-  for (size_t i = 0; i < n_groups; i++) {
-    assert(weight_zeros[i] == 0);
-    weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
-  }
-
-  // Prepack weights before packing
-  // TODO SIMDify this
-  auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 0);
-  uint8_t wzp = 8;
-  for (size_t i = 0; i < n * k; i += 2) {
-    const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
-    const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
-    packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
-  }
-
-  // Parameters for packing
-  rhs_packing::qparams_t qparams{
-      .lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
-
-  auto ukernel = get_ukernel();
-  auto rhs_pack = get_rhs_packing();
-
-  rhs_pack.run_rhs_pack(
-      /*groups=*/1,
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
+      get_ukernel(),
+      weight_data,
       n,
       k,
-      ukernel.get_nr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
       group_size,
-      /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
-      /*rhs_stride=*/roundup(k, 2) / 2,
-      /*bias=*/nullptr, // TODO fix APIs to move bias here
-      /*scale=*/reinterpret_cast<const float*>(weight_scales_bf16.data()),
-      /*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
-      /*rhs_packed=*/weight_data,
-      /*extra_bytes=*/0,
-      /*qparams=*/&qparams);
+      weight_qvals,
+      weight_scales,
+      weight_zeros);
 }
 
 void kernel(
-    // Outputs
     float32_t* output,
-    // Inputs
     int output_m_stride,
     int m,
     int n,
     int k,
     int group_size,
     const void* weight_data,
     const void* activation_data,
-    // Not applied if nullptr
     const float* bias,
-    // zeros if has_clamp = false
     float clamp_min,
     float clamp_max) {
-  assert(output_m_stride == n);
-  if (clamp_min == clamp_max && clamp_min == 0) {
-    clamp_min = std::numeric_limits<float>::lowest();
-    clamp_max = std::numeric_limits<float>::max();
-  }
-  auto ukernel = get_ukernel();
-  ukernel.run_matmul(
+  (void) bias; // unused - needs API fixing
+  assert(output_m_stride == n);
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
+
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
     m,
     n,
     k,
     activation_data,
     weight_data,
     output,
-      /*dst_stride_row=*/n * sizeof(float),
-      /*dst_stride_col=*/sizeof(float),
+      /*dst_stride_row=*/ n * sizeof(float),
+      /*dst_stride_col=*/ sizeof(float),
     clamp_min,
     clamp_max);
 }

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index e4b28c93bf..71b7857e4d 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -7,12 +7,137 @@
 
 #pragma once
 
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
 #include
 
+#include
+
 namespace torchao::kernels::cpu::aarch64::kleidi {
+
+// Helper functions
+// TODO: find a better place for these?
+
+size_t roundup(size_t a, size_t b) {
+  return ((a + b - 1) / b) * b;
+}
+
+uint16_t get_bf16_from_float(float f) {
+  uint16_t bf16;
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  memcpy(&bf16, &f, sizeof(uint16_t));
+#else
+  const void* fp = reinterpret_cast<const void*>(
+      reinterpret_cast<const char*>(&f) + sizeof(float) - sizeof(uint16_t));
+  memcpy(&bf16, fp, sizeof(uint16_t));
+#endif // __BYTE_ORDER__
+  return bf16;
+}
+
 namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
-using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
+using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
+
+int activation_data_size(const Ukernel ukernel, int m, int k) {
+  auto lhs_packing = get_lhs_packing();
+  return lhs_packing.get_lhs_packed_size(
+      m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
+}
+
+void prepare_activation_data(
+    const Ukernel ukernel,
+    void* activation_data,
+    int m,
+    int k,
+    const float* activations) {
+  auto lhs_pack = get_lhs_packing();
+
+  lhs_pack.run_lhs_pack(
+      m,
+      k,
+      ukernel.get_mr(),
+      ukernel.get_kr(),
+      ukernel.get_sr(),
+      /*m_index_start=*/0,
+      activations,
+      /*lhs_stride=*/k * sizeof(float),
+      activation_data);
+}
+
+int weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
+  auto rhs_pack = get_rhs_packing();
+  return rhs_pack.get_rhs_packed_size(
+      n,
+      k,
+      ukernel.get_nr(),
+      ukernel.get_kr(),
+      ukernel.get_sr(),
+      group_size,
+      kai_datatype::kai_dt_bf16);
+}
+
+void prepare_weight_data(
+    const Ukernel ukernel,
+    void* weight_data,
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    const int8_t* weight_zeros) {
+  // TODO - remove this constraint and pad when possible
+  assert(n % 2 == 0);
+
+  assert(group_size % 32 == 0);
+  assert(k % group_size == 0);
+
+  // TODO SIMDify this
+  size_t n_groups = n * k / group_size;
+  auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
+  for (size_t i = 0; i < n_groups; i++) {
+    assert(weight_zeros[i] == 0);
+    weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
+  }
+
+  // Prepack weights before packing
+  // TODO SIMDify this
+  auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 0);
+  uint8_t wzp = 8;
+  for (size_t i = 0; i < n * k; i += 2) {
+    const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
+    const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
+    packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
+  }
+
+  // Parameters for packing
+  rhs_packing::qparams_t qparams{
+      .lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
+
+  auto rhs_pack = get_rhs_packing();
+
+  rhs_pack.run_rhs_pack(
+      /*groups=*/1,
+      n,
+      k,
+      ukernel.get_nr(),
+      ukernel.get_kr(),
+      ukernel.get_sr(),
+      group_size,
+      /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
+      /*rhs_stride=*/roundup(k, 2) / 2,
+      /*bias=*/nullptr, // TODO fix APIs to move bias here
+      /*scale=*/reinterpret_cast<const float*>(weight_scales_bf16.data()),
+      /*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
+      /*rhs_packed=*/weight_data,
+      /*extra_bytes=*/0,
+      /*qparams=*/&qparams);
+}
 
 } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
 } // namespace torchao::kernels::cpu::aarch64::kleidi
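Annotation: the prepare_weight_data routine hoisted into the shared header above stores two signed 4-bit weights per byte after shifting by the fixed zero point of 8. A standalone sketch of just that nibble packing, matching the loop in the patch; the input buffer contents are examples:

// Standalone illustration of the 4-bit packing in prepare_weight_data:
// add the fixed zero point (8) so values land in [0, 15], then store two
// unsigned nibbles per byte, low nibble first.
#include <cstdint>
#include <vector>

std::vector<uint8_t> pack_qsu4(const std::vector<int8_t>& qvals) {
  const uint8_t wzp = 8;  // fixed weight zero point
  std::vector<uint8_t> packed(qvals.size() / 2);
  for (size_t i = 0; i < qvals.size(); i += 2) {
    const uint8_t low = static_cast<uint8_t>(qvals[i] + wzp);      // [-8,7] -> [0,15]
    const uint8_t high = static_cast<uint8_t>(qvals[i + 1] + wzp);
    packed[i / 2] = static_cast<uint8_t>((high << 4) | (low & 0xF));
  }
  return packed;
}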
changed, 240 insertions(+), 1 deletion(-) create mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h new file mode 100644 index 0000000000..6c47c4738d --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -0,0 +1,125 @@ +// namespace example +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { +namespace neon_dotprod_1x8x32 { +const Ukernel get_ukernel() { + return Ukernel{ + .get_m_step = + kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_n_step = + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_mr = + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_nr = + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_kr = + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_sr = + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_dst_offset = + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .get_dst_size = + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + .run_matmul = + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}; +} + +int activation_data_size(int m, int k, int group_size) { + (void) group_size; // unused + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k); +} + +void prepare_activation_data( + void* activation_data, + int m, + int k, + int group_size, + const float* activations) { + (void) group_size; // unused + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( + get_ukernel(), + activation_data, + m, + k, + activations); +} + +int weight_data_size(int n, int k, int group_size) { + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size); +} + +void prepare_weight_data( + void* weight_data, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( + get_ukernel(), + weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros); +} + +void kernel( + float32_t* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + const float* bias, + float clamp_min, + float clamp_max) { + (void) bias; // unused - needs API fixing + assert(output_m_stride == n); + if (clamp_min == 0 && clamp_max == 0) { + clamp_min = std::numeric_limits::lowest(); + clamp_max = std::numeric_limits::max(); + } + + auto 
ukernel = get_ukernel(); + ukernel.run_matmul( + m, + n, + k, + group_size, + activation_data, + weight_data, + output, + /*dst_stride_row=*/ n * sizeof(float), + /*dst_stride_col=*/ sizeof(float), + clamp_min, + clamp_max); +} + +size_t get_alignement() { + return 16; +} +} // namespace neon_dotprod_1x4x32 +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 90fee4ffd3..ec59df0347 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -6,13 +6,16 @@ #if defined(__aarch64__) || defined(__ARM_NEON) +#include #include + #include #include #include #include + #include -#include +#include float kTol = 0.0001; @@ -351,6 +354,9 @@ TEST( } } +// #ifdef TORCHAO_ENABLE_KLEIDI +// TODO: Wire up the the compile defination for TORCHAO_ENABLE_KLEIDI + template void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( int m, @@ -455,4 +461,112 @@ TEST( false /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } + + + +template +void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + int m, + int k, + int n, + int group_size) { + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp, + /*weight_scale_bf16_round_trip=*/true); + + using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; + + std::vector activation_data( + activation_data_size(m, k, group_size)); + + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data()); + + std::vector output(m * n); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*bias=*/test_case.bias.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + k_eq_gs_32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + large_k_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + even_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); +} + +TEST( + 
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
+    k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
+      4 /*weight_nbit*/,
+      false /*has_weight_zeros*/,
+      false /*has_bias*/,
+      false /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+// #endif // defined(TORCHAO_ENABLE_KLEIDI)

 #endif // defined(__aarch64__) || defined(__ARM_NEON)
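A note for readers of this patch: the intended call sequence through the new wrapper is sketched below. This is illustrative only, not part of the patch — `run_qb4w_linear` and its argument names are made up here, and it assumes 4-bit groupwise-quantized weights (qvals plus per-group scales) were produced elsewhere. At this point in the series, weight zero points must be all-zero, so an explicit zero array is passed.

    #include <cstdint>
    #include <vector>
    #include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>

    using namespace torchao::kernels::cpu::aarch64::kleidi::
        kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;

    void run_qb4w_linear(
        float* out, const float* activations,
        const int8_t* weight_qvals, const float* weight_scales,
        int m, int n, int k, int group_size) {
      // Dynamically quantize and pack the activations (int8, per-row).
      std::vector<char> packed_lhs(activation_data_size(m, k, group_size));
      prepare_activation_data(packed_lhs.data(), m, k, group_size, activations);

      // Pack the 4-bit weights and bf16 scales; zero points must be zero.
      std::vector<int8_t> zeros(n * k / group_size, 0);
      std::vector<char> packed_rhs(weight_data_size(n, k, group_size));
      prepare_weight_data(packed_rhs.data(), n, k, group_size,
                          weight_qvals, weight_scales, zeros.data());

      // clamp_min == clamp_max == 0 means "no clamp" per the kernel above.
      kernel(out, /*output_m_stride=*/n, m, n, k, group_size,
             packed_rhs.data(), packed_lhs.data(), /*bias=*/nullptr,
             /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
    }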
From ee62be5d7862e968e2a21114e8b87adcf3efb695 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Mon, 7 Oct 2024 22:45:45 -0500
Subject: [PATCH 11/19] [experimental][kleidi] linter

---
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 37 +++++++++----------
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h |  1 -
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h   | 15 ++++----
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 22 +++++------
 4 files changed, 34 insertions(+), 41 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
index 381ba9489d..a466010367 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -1,4 +1,3 @@
-// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //
@@ -42,8 +41,9 @@ const Ukernel get_ukernel() {
 }

 int activation_data_size(int m, int k, int group_size) {
-  (void) group_size; // unused
-  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
 }

 void prepare_activation_data(
@@ -52,17 +52,14 @@ void prepare_activation_data(
     int k,
     int group_size,
     const float* activations) {
-  (void) group_size; // unused
+  (void)group_size; // unused
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(),
-      activation_data,
-      m,
-      k,
-      activations);
+      get_ukernel(), activation_data, m, k, activations);
 }

 int weight_data_size(int n, int k, int group_size) {
-  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
 }

 void prepare_weight_data(
@@ -96,15 +93,15 @@ void kernel(
     const float* bias,
     float clamp_min,
     float clamp_max) {
-    (void) bias; // unused - needs API fixing
-    assert(output_m_stride == n);
-    if (clamp_min == 0 && clamp_max == 0) {
-      clamp_min = std::numeric_limits<float>::lowest();
-      clamp_max = std::numeric_limits<float>::max();
-    }
+  (void)bias; // unused - needs API fixing
+  assert(output_m_stride == n);
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }

-    auto ukernel = get_ukernel();
-    ukernel.run_matmul(
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
       m,
       n,
       k,
@@ -112,8 +109,8 @@
       activation_data,
       weight_data,
       output,
-      /*dst_stride_row=*/ n * sizeof(float),
-      /*dst_stride_col=*/ sizeof(float),
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
       clamp_min,
       clamp_max);
 }
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
index 6c47c4738d..22afd0b808 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
@@ -1,4 +1,3 @@
-// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index 71b7857e4d..d88e508fda 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -1,4 +1,3 @@
-// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //
@@ -7,10 +6,10 @@

 #pragma once

-#include
+#include

 #include
+#include
 #include
-#include

 #include
 #include
@@ -25,7 +24,7 @@ namespace torchao::kernels::cpu::aarch64::kleidi {

 // TODO: find a better place for these?
 size_t roundup(size_t a, size_t b) {
-    return ((a + b - 1) / b) * b;
+  return ((a + b - 1) / b) * b;
 }

 uint16_t get_bf16_from_float(float f) {
@@ -111,13 +110,15 @@ void prepare_weight_data(
   uint8_t wzp = 8;
   for (size_t i = 0; i < n * k; i += 2) {
     const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
-    const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
+    const uint8_t high = static_cast<uint8_t>(weight_qvals[i + 1] + wzp);
     packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
   }

   // Parameters for packing
   rhs_packing::qparams_t qparams{
-      .lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
+      .lhs_zero_point = 1,
+      .rhs_zero_point = wzp,
+      .scale_dt = kai_datatype::kai_dt_bf16};

   auto rhs_pack = get_rhs_packing();

@@ -133,7 +134,7 @@
       /*rhs_stride=*/roundup(k, 2) / 2,
       /*bias=*/nullptr, // TODO fix APIs to move bias here
       /*scale=*/reinterpret_cast<const float*>(weight_scales_bf16.data()),
-      /*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
+      /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size),
       /*rhs_packed=*/weight_data,
       /*extra_bytes=*/0,
       /*qparams=*/&qparams);
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index ec59df0347..ca3ba90bdc 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -6,8 +6,8 @@

 #if defined(__aarch64__) || defined(__ARM_NEON)

-#include
 #include
+#include

 #include
 #include
@@ -375,10 +375,10 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
       has_clamp,
       /*weight_scale_bf16_round_trip=*/true);

-  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;

-  std::vector<char> activation_data(
-      activation_data_size(m, k, group_size));
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));

   prepare_activation_data(
       (void*)activation_data.data(),
@@ -387,8 +387,7 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
       group_size,
       test_case.activations.data());

-  std::vector<char> weight_data(
-      weight_data_size(n, k, group_size));
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));

   prepare_weight_data(
       (void*)weight_data.data(),
@@ -462,8 +461,6 @@ TEST(
     /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }

-
-
 template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
     int m,
     int k,
@@ -482,10 +479,10 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
       has_clamp,
       /*weight_scale_bf16_round_trip=*/true);

-  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;

-  std::vector<char> activation_data(
-      activation_data_size(m, k, group_size));
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));

   prepare_activation_data(
       (void*)activation_data.data(),
@@ -494,8 +491,7 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
       group_size,
       test_case.activations.data());

-  std::vector<char> weight_data(
-      weight_data_size(n, k, group_size));
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));

   prepare_weight_data(
       (void*)weight_data.data(),
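Two details of `prepare_weight_data` touched by the lint pass above are worth spelling out: int4 values are shifted by a fixed zero point of 8 and packed two per byte, and the per-group scales are stored as bf16. A small standalone sketch of both, illustrative only and not part of the patch:

    #include <cstdint>
    #include <cstring>

    // Pack two signed int4 values (range [-8, 7]) into one byte, using the
    // same fixed zero point of 8 as the patch: q + 8 lands in [0, 15].
    inline uint8_t pack_qsu4(int8_t lo, int8_t hi) {
      const uint8_t l = static_cast<uint8_t>(lo + 8);
      const uint8_t h = static_cast<uint8_t>(hi + 8);
      return static_cast<uint8_t>((h << 4) | (l & 0xF));
    }

    // bf16 keeps a float's sign, exponent, and top 7 mantissa bits; the
    // byte-copy in get_bf16_from_float is a truncation, not a rounding.
    // (This endian-agnostic form matches the patch on little-endian.)
    inline uint16_t to_bf16(float f) {
      uint32_t u;
      std::memcpy(&u, &f, sizeof(u));
      return static_cast<uint16_t>(u >> 16);
    }

    inline float from_bf16(uint16_t b) {
      const uint32_t u = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    // from_bf16(to_bf16(s)) generally differs from s in the low mantissa
    // bits (relative error up to about 2^-8), which is why the tests
    // round-trip weight scales before computing expected outputs.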
From ee49c6eb0e6153c28e32c2d8c2d31ca7b8a5dc03 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 8 Oct 2024 17:44:38 -0500
Subject: [PATCH 12/19] [experimental][kleidi] Reduce template types for tests

---
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 46 ++++++++++---------
 .../kernels/cpu/aarch64/tests/test_utils.h    | 39 ----------------
 2 files changed, 24 insertions(+), 61 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index ca3ba90bdc..29072b3029 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -357,7 +357,7 @@ TEST(
 // #ifdef TORCHAO_ENABLE_KLEIDI
 // TODO: Wire up the compile definition for TORCHAO_ENABLE_KLEIDI

-template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+template <bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
     int m,
     int k,
@@ -369,8 +369,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
         k,
         n,
         group_size,
-        weight_nbit,
-        has_weight_zeros,
+        /*weight_nbit=*/4,
+        /*has_weight_zeros*/false,
         has_bias,
         has_clamp,
         /*weight_scale_bf16_round_trip=*/true);
@@ -421,8 +421,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     k_eq_gs_32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
@@ -432,8 +430,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     large_k_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
@@ -443,8 +439,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     even_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
 }
@@ -454,14 +448,21 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     k_eq_gs128) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }

-template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+template <bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
     int m,
     int k,
@@ -473,8 +474,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
         k,
         n,
         group_size,
-        weight_nbit,
-        has_weight_zeros,
+        /*weight_nbit=*/4,
+        /*has_weight_zeros=*/false,
         has_bias,
         has_clamp,
         /*weight_scale_bf16_round_trip=*/true);
@@ -525,8 +526,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     k_eq_gs_32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
@@ -536,8 +535,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     large_k_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
@@ -547,8 +544,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     even_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
@@ -558,11 +553,18 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     k_eq_gs128) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
+    clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
index 4316f8f2d6..25d3337033 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -235,7 +235,6 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
           zero,
           qmin,
           qmax);
-      // std::fill(weight_qvals.begin(), weight_qvals.end(), -7);
     }

     std::vector<float> bias(m, 0.0);
@@ -277,44 +276,6 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
       }
     }

-#if 0 // Alternate reference implementation for debugging.
-    auto num_groups = k / weight_group_size;
-    for (int m_idx = 0; m_idx < m; m_idx++) {
-      for (int n_idx = 0; n_idx < n; n_idx++) {
-        int32_t result_idx = m_idx * n + n_idx;
-        float weights_fsum = 0.0;
-        for (int g_idx = 0; g_idx < num_groups; g_idx++) {
-          int32_t weights_qsum = 0;
-          int32_t acc_i32 = 0;
-          for (int k_idx = 0; k_idx < weight_group_size; k_idx++) {
-            const int32_t activation_idx = m_idx * k + g_idx * weight_group_size + k_idx;
-            const int32_t weight_idx = n_idx * k + g_idx * weight_group_size + k_idx;
-
-            const int32_t weight_qval = weight_qvals[weight_idx];
-            const int32_t activation_qval = activation_qvals[activation_idx];
-
-            weights_qsum += weight_qval;
-            acc_i32 += weight_qval * activation_qval;
-          }
-          // For each group, we have a weight scale
-          const int32_t weight_scale_idx = n_idx * num_groups + g_idx;
-          const float weight_scale = weight_scales[weight_scale_idx]; // already round tripped to bf16
-          expected_output[result_idx] += (float) acc_i32 * weight_scales[weight_scale_idx];
-          weights_fsum += weights_qsum * weight_scale;
-        }
-        // For each output channel, we have an activation scale
-        const int32_t activation_zero_point = activation_zeros[m_idx];
-        const float activation_scale = activation_scales[m_idx];
-        expected_output[result_idx] -= activation_zero_point * weights_fsum;
-        expected_output[result_idx] *= activation_scale;
-        expected_output[result_idx] += bias[m_idx];
-        if (has_clamp) {
-          expected_output[result_idx] = std::min(std::max(expected_output[result_idx], clamp_min), clamp_max);
-        }
-      }
-    }
-#endif
-
     // Return test case
     return channelwise_8bit_activation_groupwise_lowbit_weight_test_case(
         m,
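For the record, the quantized reference that the deleted #if 0 block implemented — and that the remaining test path still encodes — is, per output element with group size gs:

    out[i][j] = act_scale[i] * ( sum_g ws[j][g] * sum_k aq[i][g*gs+k] * wq[j][g*gs+k]
                - act_zero[i] * sum_g ws[j][g] * sum_k wq[j][g*gs+k] ) + bias[i]

followed by an optional clamp to [clamp_min, clamp_max], where ws are the per-group weight scales (already round-tripped through bf16), and act_scale / act_zero are the per-row dynamic activation quantization parameters.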
From a905ec3154aff1d1cc6f638506b011243669c9d2 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 10:09:11 -0500
Subject: [PATCH 13/19] [experimental][kleidi] Add m>1 tests

---
 .../kernels/cpu/aarch64/tests/test_linear.cpp | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index 29072b3029..c3135b742b 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -462,6 +462,16 @@ TEST(
     /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }

+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    m_clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+
 template <bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
     int m,
@@ -566,5 +576,14 @@ TEST(
       true /*has_clamp*/>(
       /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
+    m_clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
 // #endif // defined(TORCHAO_ENABLE_KLEIDI)
 #endif // defined(__aarch64__) || defined(__ARM_NEON)

From 7429beaaa9be766029ec0d20028dac2d409cfc20 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 10:18:37 -0500
Subject: [PATCH 14/19] [experimental][kleidi] rename bf16 weight scale flag

---
 .../experimental/kernels/cpu/aarch64/tests/test_linear.cpp | 2 +-
 torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index c3135b742b..8b0b013168 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -488,7 +488,7 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
         /*has_weight_zeros=*/false,
         has_bias,
         has_clamp,
-        /*weight_scale_bf16_round_trip=*/true);
+        /*round_weight_scales_to_bf16=*/true);

   using namespace torchao::kernels::cpu::aarch64::kleidi::
       kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
index 25d3337033..c684405446 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -156,7 +156,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
       bool has_weight_zeros,
       bool has_bias,
       bool has_clamp,
-      bool weight_scale_bf16_round_trip=false) {
+      bool round_weight_scales_to_bf16=false) {
     // activations is m x k (stored in row-major)
     // weights is k x n (stored in column-major)

@@ -219,7 +219,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
         scale = torchao::quantization::get_scale(vmin, vmax, qmin, qmax);
         zero = 0;
       }
-      if (weight_scale_bf16_round_trip) {
+      if (round_weight_scales_to_bf16) {
         // weight scales are bf16 in the kernel
        // so we need to round trip them to bf16 and back to float to match it.
         scale = get_float_from_bf16(get_bf16_from_float(scale));

From f28e556f3085c8c41af5618e5ea371439edc25e7 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 10:19:13 -0500
Subject: [PATCH 15/19] [experimental][kleidi] Build kernel tests in debug
 mode

---
 .../kernels/cpu/aarch64/tests/build_and_run_tests.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
index 98ce559914..27c584ce20 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
@@ -9,9 +9,9 @@ set -eu

 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
 export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
-cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT}
+cmake -DCMAKE_BUILD_TYPE=Debug -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT}

-cmake --build ${CMAKE_OUT} 
+cmake --build ${CMAKE_OUT}

 # Run
 ${CMAKE_OUT}/test_quantization

From 17f2b43fa57e0416d93721d879172756c1058952 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 13:01:20 -0500
Subject: [PATCH 16/19] [experimental][kleidi] Add TODO tasks

---
 ...mul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 2 +-
 ...mul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 2 +-
 .../aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h    | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
index a466010367..d62b4b7c0f 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -93,7 +93,7 @@ void kernel(
     const float* bias,
     float clamp_min,
     float clamp_max) {
-  (void)bias; // unused - needs API fixing
+  (void)bias; // TODO(T203756650) - unused - needs API fixing
   assert(output_m_stride == n);
   if (clamp_min == 0 && clamp_max == 0) {
     clamp_min = std::numeric_limits<float>::lowest();
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
index 22afd0b808..6fc2ccbdb5 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
@@ -94,7 +94,7 @@ void kernel(
     const float* bias,
     float clamp_min,
     float clamp_max) {
-  (void) bias; // unused - needs API fixing
+  (void) bias; // TODO(T203756650) - unused - needs API fixing
   assert(output_m_stride == n);
   if (clamp_min == 0 && clamp_max == 0) {
     clamp_min = std::numeric_limits<float>::lowest();
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index d88e508fda..c7ed5bc8f3 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -90,7 +90,7 @@ void prepare_weight_data(
     const int8_t* weight_qvals,
     const float* weight_scales,
     const int8_t* weight_zeros) {
-  // TODO - remove this constraint and pad when possible
+  // TODO(T204312268) - remove this constraint and pad when possible
   assert(n % 2 == 0);
   assert(group_size % 32 == 0);

@@ -132,7 +132,7 @@
       group_size,
       /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
       /*rhs_stride=*/roundup(k, 2) / 2,
-      /*bias=*/nullptr, // TODO fix APIs to move bias here
+      /*bias=*/nullptr, // TODO(T203756650) fix APIs to move bias here
       /*scale=*/reinterpret_cast<const float*>(weight_scales_bf16.data()),
       /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size),
       /*rhs_packed=*/weight_data,

From 3049ded51266f3c1c285bc65a53d9ce651cd6a79 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 13:02:32 -0500
Subject: [PATCH 17/19] [experimental][kleidi] Allow weight zeros to be a
 nullptr

---
 .../kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index c7ed5bc8f3..71426370af 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -99,8 +99,15 @@ void prepare_weight_data(
   // TODO SIMDify this
   size_t n_groups = n * k / group_size;
   auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
+
+  // We don't support weight zeros yet
+  if (weight_zeros != nullptr) {
+    for (size_t i = 0; i < n_groups; i++) {
+      assert(weight_zeros[i] == 0);
+    }
+  }
+
   for (size_t i = 0; i < n_groups; i++) {
-    assert(weight_zeros[i] == 0);
     weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
   }
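With the check relaxed as above, callers of either ukernel wrapper's `prepare_weight_data` have two equivalent options for symmetric 4-bit weights. A sketch — `packed`, `qvals`, and `scales` are placeholders, not names from the patch:

    // Option 1: symmetric quantization, no zero-point buffer at all.
    neon_dotprod_1x4x32::prepare_weight_data(
        packed, n, k, group_size, qvals, scales, /*weight_zeros=*/nullptr);

    // Option 2: an explicit all-zero array still works; any non-zero entry
    // trips the assert, since asymmetric weights remain unsupported.
    std::vector<int8_t> zeros(n * k / group_size, 0);
    neon_dotprod_1x4x32::prepare_weight_data(
        packed, n, k, group_size, qvals, scales, zeros.data());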
From d4bb3ed08a21ccf06bf6f89d2825ccbd5211cae4 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 13:53:40 -0500
Subject: [PATCH 18/19] [experimental][kleidi] rebase fixes with int to size_t

---
 ...l_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 6 +++---
 ...l_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 6 +++---
 .../aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h  | 4 ++--
 torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h | 6 ++++--
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
index d62b4b7c0f..ded42df9db 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -40,7 +40,7 @@ const Ukernel get_ukernel() {
       kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
 }

-int activation_data_size(int m, int k, int group_size) {
+size_t activation_data_size(int m, int k, int group_size) {
   (void)group_size; // unused
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
       get_ukernel(), m, k);
@@ -57,7 +57,7 @@ void prepare_activation_data(
       get_ukernel(), activation_data, m, k, activations);
 }

-int weight_data_size(int n, int k, int group_size) {
+size_t weight_data_size(int n, int k, int group_size) {
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
       get_ukernel(), n, k, group_size);
 }
@@ -115,7 +115,7 @@ void kernel(
       clamp_max);
 }

-size_t get_alignement() {
+size_t get_preferred_alignement() {
   return 16;
 }
 } // namespace neon_dotprod_1x4x32
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
index 6fc2ccbdb5..116f25bd59 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
@@ -39,7 +39,7 @@ const Ukernel get_ukernel() {
       kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod};
 }

-int activation_data_size(int m, int k, int group_size) {
+size_t activation_data_size(int m, int k, int group_size) {
   (void) group_size; // unused
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
 }
@@ -59,7 +59,7 @@ void prepare_activation_data(
       activations);
 }

-int weight_data_size(int n, int k, int group_size) {
+size_t weight_data_size(int n, int k, int group_size) {
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
 }
@@ -116,7 +116,7 @@ void kernel(
       clamp_max);
 }

-size_t get_alignement() {
+size_t get_preferred_alignement() {
   return 16;
 }
 } // namespace neon_dotprod_1x8x32
diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index 71426370af..ae971f454f 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -43,7 +43,7 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

 using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

-int activation_data_size(const Ukernel ukernel, int m, int k) {
+size_t activation_data_size(const Ukernel ukernel, int m, int k) {
   auto lhs_packing = get_lhs_packing();
   return lhs_packing.get_lhs_packed_size(
       m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
@@ -69,7 +69,7 @@ void prepare_activation_data(
       activation_data);
 }

-int weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
+size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
   auto rhs_pack = get_rhs_packing();
   return rhs_pack.get_rhs_packed_size(
       n,
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
index c684405446..e9f36e14ac 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -44,7 +44,8 @@ inline std::vector<uint8_t> get_random_lowbit_vector(int size, int nbit) {
 }

 // TODO move these to a common utils
-uint16_t get_bf16_from_float(float f) {
+inline uint16_t
+get_bf16_from_float(float f) {
   uint16_t bf16;
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
   memcpy(&bf16, &f, sizeof(uint16_t));
@@ -56,7 +57,8 @@
   return bf16;
 }

-float get_float_from_bf16(uint16_t bf16) {
+inline float
+get_float_from_bf16(uint16_t bf16) {
   float f;
   const uint32_t i32 = (bf16 << 16);
   memcpy(&f, &i32, sizeof(uint32_t));
From f6e22fb5e51df8d2337b8d6bbc55638c75058e2c Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 10 Oct 2024 19:05:03 -0500
Subject: [PATCH 19/19] [experimental][kleidi] compile-time preprocessor
 switch for kleidi tests

---
 torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt    | 5 +++--
 .../experimental/kernels/cpu/aarch64/tests/CMakeLists.txt  | 7 +++++++
 .../kernels/cpu/aarch64/tests/test_linear.cpp              | 8 ++++----
 3 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
index da99a55ed1..6541d7fdac 100644
--- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -20,8 +20,6 @@ FetchContent_MakeAvailable(kleidiai)
 CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library"
                        OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON)

-add_compile_definitions("TORCHAO_ENABLE_KLEIDI=$<BOOL:${BUILD_KLEIDI}>")
-
 if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
   add_library(
     torchao_kernels_aarch64
@@ -31,6 +29,9 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
     ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
   )
   if (BUILD_KLEIDI)
+    # Temporarily exposing this to the parent scope until we wire
+    # this up properly from the top level
+    set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
     message(STATUS "Building with Kleidi")
     target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
   endif()
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
index 29f38b57fa..c9799eadd9 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
@@ -32,8 +32,15 @@ add_library(
 if(NOT TORCHAO_INCLUDE_DIRS)
   set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES})
 endif()
+
 add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

+# The TORCHAO_ENABLE_KLEIDI cmake variable should be set by `torchao_kernels_aarch64`.
+# This is a temporary workaround.
+if(TORCHAO_ENABLE_KLEIDI)
+  add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
+endif()
+
 enable_testing()

 add_executable(test_quantization test_quantization.cpp)
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index 8b0b013168..571b91d476 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -14,8 +14,10 @@
 #include
 #include

+#ifdef TORCHAO_ENABLE_KLEIDI
 #include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
 #include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
+#endif

 float kTol = 0.0001;

@@ -354,9 +356,7 @@ TEST(
   }
 }

-// #ifdef TORCHAO_ENABLE_KLEIDI
-// TODO: Wire up the compile definition for TORCHAO_ENABLE_KLEIDI
-
+#ifdef TORCHAO_ENABLE_KLEIDI
 template <bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
     int m,
@@ -585,5 +585,5 @@ TEST(
       true /*has_clamp*/>(
       /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
-// #endif // defined(TORCHAO_ENABLE_KLEIDI)
+#endif // TORCHAO_ENABLE_KLEIDI
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
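Taken together, patch 19 means downstream code should guard both the Kleidi includes and any Kleidi-dispatching call sites on the same macro. A sketch of the intended consumption pattern — `packed_weight_bytes` is a hypothetical helper, not part of the patch:

    #ifdef TORCHAO_ENABLE_KLEIDI
    #include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
    #endif

    #include <cstddef>

    // Prefer the Kleidi ukernel when the build linked it in; otherwise
    // callers would fall back to the existing non-Kleidi packing path.
    size_t packed_weight_bytes(int n, int k, int group_size) {
    #ifdef TORCHAO_ENABLE_KLEIDI
      return torchao::kernels::cpu::aarch64::kleidi::
          kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32::
              weight_data_size(n, k, group_size);
    #else
      return 0; // dispatch to the fallback packing routine here
    #endif
    }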