
Commit 41b38f7

nikhil-arm authored and pytorchmergebot committed
Revert "Reverting the PR adding Kleidiai-based int4 kernels (pytorch#145392)" (pytorch#145505)
pytorch#134124 was reverted by pytorch#145392 due to a KleidiAI clone issue.

1. This reverts commit 0940eb6 (pytorch#145392) and fixes the KleidiAI mirror issue.
2. KleidiAI is now cloned from the GitHub mirror instead of Arm GitLab.

Change-Id: I7d6eee7214cd117d3057d615936fcc3ee6052fa2

Fixes pytorch#145273

Pull Request resolved: pytorch#145505
Approved by: https://github.com/malfet
1 parent 34b8d8b commit 41b38f7

37 files changed: +1940 −23 lines

.gitmodules (+3)

@@ -131,3 +131,6 @@
 	path = third_party/composable_kernel
 	url = https://github.com/ROCm/composable_kernel.git
 	branch = develop
+[submodule "third_party/kleidiai"]
+	path = third_party/kleidiai
+	url = https://github.com/ARM-software/kleidiai.git

BUILD.bazel (+2)

@@ -257,6 +257,7 @@ filegroup(
     # target that generates these sources...
 )
 
+# TODO: Enable support for KleidiAI bazel build
 header_template_rule(
     name = "aten_src_ATen_config",
     src = "aten/src/ATen/Config.h.in",
@@ -276,6 +277,7 @@ header_template_rule(
         "@AT_PARALLEL_NATIVE@": "1",
         "@AT_BLAS_F2C@": "0",
         "@AT_BLAS_USE_CBLAS_DOT@": "1",
+        "@AT_KLEIDIAI_ENABLED@": "0",
     },
 )

CMakeLists.txt (+7)

@@ -377,6 +377,8 @@ cmake_dependent_option(
 cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
 cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
                        OFF "USE_CUDA" OFF)
+cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
+                       "CPU_AARCH64" OFF)
 
 option(USE_MIMALLOC "Use mimalloc" OFF)
 # Enable third party mimalloc library to improve memory allocation performance
@@ -418,6 +420,8 @@ endif()
 if(WIN32)
   set(USE_TENSORPIPE OFF)
   message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF")
+  set(USE_KLEIDIAI OFF)
+  message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF")
 
   if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
     find_library(
@@ -667,6 +671,9 @@ if(ANDROID
   message(WARNING "INTERN_BUILD_MOBILE is on, disabling BUILD_LAZY_TS_BACKEND")
   set(BUILD_LAZY_TS_BACKEND OFF)
 
+  set(USE_KLEIDIAI OFF)
+  message(WARNING "KleidiAI cannot be used on Mobile builds. Set it to OFF")
+
   # Set -ffunction-sections and -fdata-sections so that each method has its own
   # text section. This allows the linker to remove unused section when the flag
   # -Wl,-gc-sections is provided at link time.

WORKSPACE (+6)

@@ -309,6 +309,12 @@ local_repository(
     path = "third_party/gemmlowp/gemmlowp",
 )
 
+local_repository(
+    name = "kleidiai",
+    path = "third_party/kleidiai",
+    repo_mapping = {"@com_google_googletest": "@com_google_benchmark"},
+)
+
 ### Unused repos start
 
 # `unused` repos are defined to hide bazel files from submodules of submodules.

aten/src/ATen/CMakeLists.txt (+9, −1)

@@ -219,6 +219,10 @@ endif()
 # XNNPACK
 file(GLOB native_xnnpack "native/xnnpack/*.cpp")
 
+# KLEIDIAI
+file(GLOB native_kleidiai "native/kleidiai/*.cpp")
+file(GLOB native_kleidiai_h "native/kleidiai/*.h")
+
 # Add files needed from jit folders
 append_filelist("jit_core_headers" ATen_CORE_HEADERS)
 append_filelist("jit_core_sources" ATen_CORE_SRCS)
@@ -248,6 +252,10 @@ endif()
 if(AT_MKL_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
 endif()
+if(AT_KLEIDIAI_ENABLED)
+  set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
+  include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
+endif()
 if(AT_MKLDNN_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
 endif()
@@ -637,7 +645,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
 
 set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS})
 if(NOT INTERN_BUILD_MOBILE)
-  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
+  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
   # Metal
   if(USE_PYTORCH_METAL_EXPORT)
     # Add files needed from exporting metal models(optimized_for_mobile)

aten/src/ATen/Config.h.in (+1)

@@ -19,3 +19,4 @@
 #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@
 #define AT_BLAS_F2C() @AT_BLAS_F2C@
 #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
+#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@

aten/src/ATen/Context.cpp (+4)

@@ -435,6 +435,10 @@ bool Context::hasMKLDNN() {
 #endif
 }
 
+bool Context::hasKleidiAI() {
+  return AT_KLEIDIAI_ENABLED();
+}
+
 bool Context::hasOpenMP() {
 #ifdef _OPENMP
   return true;

aten/src/ATen/Context.h (+5)

@@ -119,6 +119,7 @@ class TORCH_API Context {
 
   static bool hasOpenMP();
   static bool hasMKL();
+  static bool hasKleidiAI();
   static bool hasLAPACK();
   static bool hasMKLDNN();
   static bool hasMAGMA() {
@@ -550,6 +551,10 @@ inline bool hasMKL() {
   return globalContext().hasMKL();
 }
 
+inline bool hasKleidiAI() {
+  return globalContext().hasKleidiAI();
+}
+
 inline bool hasLAPACK() {
   return globalContext().hasLAPACK();
 }
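
For readers checking the flag from C++, a minimal usage sketch (assumptions: a standalone program compiled and linked against the ATen/libtorch headers and library). at::hasKleidiAI() is the free-function wrapper added to Context.h above; per Context.cpp it just returns AT_KLEIDIAI_ENABLED(), so it reports a build-time property rather than probing the CPU at runtime.

// Sketch only, not part of this commit: query the new capability flag.
// Returns true only when ATen was built with USE_KLEIDIAI=ON (aarch64).
#include <ATen/Context.h>
#include <iostream>

int main() {
  std::cout << "KleidiAI kernels built in: " << std::boolalpha
            << at::hasKleidiAI() << '\n';
  return 0;
}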

aten/src/ATen/native/LinearAlgebra.cpp (+67)

@@ -33,6 +33,8 @@
 #include <ATen/ops/_addmm_activation_native.h>
 #include <ATen/ops/_compute_linear_combination_native.h>
 #include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
+#include <ATen/ops/_dyn_quant_matmul_4bit_native.h>
+#include <ATen/ops/_dyn_quant_pack_4bit_weight_native.h>
 #include <ATen/ops/_int_mm_native.h>
 #include <ATen/ops/_linalg_check_errors.h>
 #include <ATen/ops/_linalg_det.h>
@@ -3429,6 +3431,8 @@ Tensor kron(const Tensor& self, const Tensor& other) {
 DEFINE_DISPATCH(weight_to_int4pack_stub);
 DEFINE_DISPATCH(int4pack_mm_stub);
 DEFINE_DISPATCH(int8pack_mm_stub);
+DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub);
+DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub);
 
 Tensor _convert_weight_to_int4pack_cpu(
     const Tensor& in,
@@ -3492,6 +3496,69 @@ Tensor _weight_int4pack_mm_cpu(
   return C;
 }
 
+Tensor _dyn_quant_pack_4bit_weight_cpu(
+    const Tensor& weights,
+    const Tensor& scales_zeros,
+    const std::optional<Tensor>& bias,
+    const int64_t block_size,
+    const int64_t in_features,
+    const int64_t out_features) {
+  TORCH_CHECK(
+      weights.dtype() == at::kByte, __func__, " : expect weight to be kByte.");
+  TORCH_CHECK(
+      block_size == in_features ||
+          (!(block_size % 32) && !(in_features % block_size)),
+      __func__,
+      ": Group size should be multiple of 32, in_features [",
+      in_features,
+      "]. Provided ",
+      block_size);
+  Tensor packed_weights =
+      at::empty(weights.sizes(), weights.options().dtype(at::kByte));
+  dyn_quant_pack_4bit_weight_stub(
+      kCPU,
+      packed_weights,
+      weights,
+      scales_zeros,
+      bias,
+      out_features,
+      in_features,
+      block_size);
+  return packed_weights;
+}
+
+Tensor _dyn_quant_matmul_4bit_cpu(
+    const Tensor& inp,
+    const Tensor& packed_weights,
+    const int64_t block_size,
+    const int64_t in_features,
+    const int64_t out_features) {
+  auto M = inp.size(0);
+  TORCH_CHECK(
+      inp.dtype() == kFloat,
+      __func__,
+      " : expect input to be 32-bit float tensor.");
+  TORCH_CHECK(
+      block_size == in_features ||
+          (!(block_size % 32) && !(in_features % block_size)),
+      __func__,
+      ": Group size should be multiple of 32, in_features [",
+      in_features,
+      "]. Provided ",
+      block_size);
+  auto output = at::empty({M, out_features}, inp.options());
+  dyn_quant_matmul_4bit_stub(
+      kCPU,
+      output,
+      inp,
+      packed_weights,
+      M,
+      out_features,
+      in_features,
+      block_size);
+  return output;
+}
+
 Tensor _weight_int8pack_mm_cpu(
     const Tensor& A,
     const Tensor& B,
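
To make the new entry points concrete, here is a hedged C++ sketch of how the pack-then-matmul pair is meant to be driven. Assumptions not shown in this hunk: the dispatcher-level op names at::_dyn_quant_pack_4bit_weight and at::_dyn_quant_matmul_4bit are inferred from the *_native.h includes added above, and the weight/scale layouts (4-bit values packed two per byte, per-group scales and zero points) are illustrative only.

// Hedged sketch, not taken from this commit: drive the dynamic 4-bit path
// end to end. Op names and tensor layouts are assumptions (see lead-in).
#include <ATen/ATen.h>

int main() {
  const int64_t M = 4, in_features = 64, out_features = 32;
  // One quantization group per row; otherwise block_size must be a
  // multiple of 32 that divides in_features (per the TORCH_CHECKs above).
  const int64_t block_size = in_features;

  // 4-bit weights packed two per byte, plus per-group quantization params.
  at::Tensor weights =
      at::randint(0, 256, {out_features, in_features / 2}, at::kByte);
  at::Tensor scales_zeros = at::rand({out_features, 2}, at::kFloat);

  // Repack into the layout the backing kernel prefers (KleidiAI when built in).
  at::Tensor packed = at::_dyn_quant_pack_4bit_weight(
      weights, scales_zeros, /*bias=*/{}, block_size, in_features, out_features);

  // Dynamically quantize the fp32 activations and run the 4-bit matmul.
  at::Tensor inp = at::rand({M, in_features}, at::kFloat);
  at::Tensor out = at::_dyn_quant_matmul_4bit(
      inp, packed, block_size, in_features, out_features);
  // out has shape [M, out_features] and dtype float32.
  return 0;
}

The kernels registered behind dyn_quant_pack_4bit_weight_stub and dyn_quant_matmul_4bit_stub live outside this hunk, so the sketch assumes an aarch64 build configured with USE_KLEIDIAI=ON.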
