From 9951eb5c7eb63b0b2acbc8b1c46f1f3a51aa0302 Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Tue, 14 May 2024 15:12:22 +0200 Subject: [PATCH 1/9] Add vec_caps and inner implementation for AVX-512-F (without VNNI) --- libs/vec/native/src/vec/c/amd64/vec.c | 66 ++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/libs/vec/native/src/vec/c/amd64/vec.c b/libs/vec/native/src/vec/c/amd64/vec.c index c9a49ad2d1d4d..79b2d1cd7d27f 100644 --- a/libs/vec/native/src/vec/c/amd64/vec.c +++ b/libs/vec/native/src/vec/c/amd64/vec.c @@ -67,9 +67,19 @@ EXPORT int vec_caps() { if (functionIds >= 7) { cpuid(cpuInfo, 7); int ebx = cpuInfo[1]; + int ecx = cpuInfo[2]; // AVX2 flag is the 5th bit // We assume that all processors that have AVX2 also have FMA3 - return (ebx & (1 << 5)) != 0; + int avx2 = (ebx & 0x00000020) != 0; + int avx512 = (ebx & 0x00010000) != 0; + // int avx512_vnni = (ecx & 0x00000800) != 0; + // if (avx512 && avx512_vnni) { + if (avx512) { + return 2; + } + if (avx2) { + return 1; + } } return 0; } @@ -111,6 +121,34 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) { return res; } +static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { + const __m256i ones = _mm256_set1_epi16(1); + + // Init accumulator(s) with 0 + __m512i acc1 = _mm512_setzero_si512(); + +#pragma GCC unroll 4 + for(int i = 0; i < dims; i += sizeof(__m512i)) { + // Load 32 packed 8-bit integers + __m512i va1 = _mm512_loadu_si512(a + i); + __m512i vb1 = _mm512_loadu_si512(b + i); + + // Perform multiplication and create 16-bit values + // Vertically multiply each unsigned 8-bit integer from va with the corresponding + // signed 8-bit integer from vb, producing intermediate signed 16-bit integers. + // These values will be at max 32385, at min −32640, + // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + + // VNNI + //acc1 = _mm512_dpbusd_epi32(acc1, va1, vb1); + const __m512i vab = _mm512_maddubs_epi16(va, vb); + acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, vab), acc1); + } + + // reduce (accumulate all) + return _mm512_reduce_add_epi32(acc1); +} + static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { // Init accumulator(s) with 0 __m256i acc1 = _mm256_setzero_si256(); @@ -126,7 +164,6 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { const __m256i dist1 = _mm256_sub_epi8(va1, vb1); const __m256i abs_dist1 = _mm256_sign_epi8(dist1, dist1); const __m256i sqr1 = _mm256_maddubs_epi16(abs_dist1, abs_dist1); - acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, sqr1), acc1); } @@ -148,3 +185,28 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { return res; } +static inline int32_t sqr7u_avx512(int8_t *a, int8_t *b, size_t dims) { + // Init accumulator(s) with 0 + __m512i acc1 = _mm512_setzero_si512(); + + const __m512i ones = _mm512_set1_epi16(1); + +#pragma GCC unroll 4 + for(int i = 0; i < dims; i += sizeof(__m512i)) { + // Load packed 8-bit integers + __m512i va = _mm512_loadu_si512(a + i); + __m512i vb = _mm512_loadu_si512(b + i); + + const __m512i dist = _mm512_sub_epi8(va, vb); + const __m512i abs_dist = _mm512_sign_epi8(dist, dist); + + // VNNI + //acc1 = _mm512_dpbusd_epi32(acc1, abs_dist, abs_dist); + const __m512i sqr = _mm512_maddubs_epi16(abs_dist, abs_dist); + acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, sqr), acc1); + } + + // reduce (accumulate all) + return _mm512_reduce_add_epi32(acc1); +} + From 98e677fde58afd08c84c22a9af2a00ae4a187a58 Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Thu, 23 May 2024 10:16:29 +0200 Subject: [PATCH 2/9] WIP --- .../nativeaccess/jdk/JdkVectorLibrary.java | 11 +- libs/vec/native/Dockerfile | 2 +- libs/vec/native/src/vec/c/amd64/vec.c | 58 +-------- libs/vec/native/src/vec/c/amd64/vec_2.c | 120 ++++++++++++++++++ 4 files changed, 132 insertions(+), 59 deletions(-) create mode 100644 libs/vec/native/src/vec/c/amd64/vec_2.c diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index e49b1985d6431..379b813e4657d 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -23,6 +23,9 @@ public final class JdkVectorLibrary implements VectorLibrary { + static final String DOT7U_NAME; + static final String SQR7U_NAME; + static final VectorSimilarityFunctions INSTANCE; static { @@ -32,8 +35,12 @@ public final class JdkVectorLibrary implements VectorLibrary { try { int caps = (int) vecCaps$mh.invokeExact(); if (caps != 0) { + DOT7U_NAME = "dot7u"; + SQR7U_NAME = "sqr7u"; INSTANCE = new JdkVectorSimilarityFunctions(); } else { + DOT7U_NAME = null; + SQR7U_NAME = null; INSTANCE = null; } } catch (Throwable t) { @@ -50,8 +57,8 @@ public VectorSimilarityFunctions getVectorSimilarityFunctions() { private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions { - static final MethodHandle dot7u$mh = downcallHandle("dot7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - static final MethodHandle sqr7u$mh = downcallHandle("sqr7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + static final MethodHandle dot7u$mh = downcallHandle(DOT7U_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + static final MethodHandle sqr7u$mh = downcallHandle(SQR7U_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); /** * Computes the dot product of given unsigned int7 byte vectors. diff --git a/libs/vec/native/Dockerfile b/libs/vec/native/Dockerfile index 66eb7e92ef479..607e206615ad1 100644 --- a/libs/vec/native/Dockerfile +++ b/libs/vec/native/Dockerfile @@ -4,7 +4,7 @@ RUN apt update RUN apt install -y gcc g++ openjdk-17-jdk COPY . /workspace WORKDIR /workspace -RUN ./gradlew --quiet --console=plain clean buildSharedLibrary +RUN ./gradlew --console=plain clean buildSharedLibrary RUN strip --strip-unneeded build/output/libvec.so CMD cat build/output/libvec.so diff --git a/libs/vec/native/src/vec/c/amd64/vec.c b/libs/vec/native/src/vec/c/amd64/vec.c index 79b2d1cd7d27f..654009ef02f99 100644 --- a/libs/vec/native/src/vec/c/amd64/vec.c +++ b/libs/vec/native/src/vec/c/amd64/vec.c @@ -14,11 +14,11 @@ #include #ifndef DOT7U_STRIDE_BYTES_LEN -#define DOT7U_STRIDE_BYTES_LEN 32 // Must be a power of 2 +#define DOT7U_STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2 #endif #ifndef SQR7U_STRIDE_BYTES_LEN -#define SQR7U_STRIDE_BYTES_LEN 32 // Must be a power of 2 +#define SQR7U_STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2 #endif #ifdef _MSC_VER @@ -121,34 +121,6 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) { return res; } -static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { - const __m256i ones = _mm256_set1_epi16(1); - - // Init accumulator(s) with 0 - __m512i acc1 = _mm512_setzero_si512(); - -#pragma GCC unroll 4 - for(int i = 0; i < dims; i += sizeof(__m512i)) { - // Load 32 packed 8-bit integers - __m512i va1 = _mm512_loadu_si512(a + i); - __m512i vb1 = _mm512_loadu_si512(b + i); - - // Perform multiplication and create 16-bit values - // Vertically multiply each unsigned 8-bit integer from va with the corresponding - // signed 8-bit integer from vb, producing intermediate signed 16-bit integers. - // These values will be at max 32385, at min −32640, - // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. - - // VNNI - //acc1 = _mm512_dpbusd_epi32(acc1, va1, vb1); - const __m512i vab = _mm512_maddubs_epi16(va, vb); - acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, vab), acc1); - } - - // reduce (accumulate all) - return _mm512_reduce_add_epi32(acc1); -} - static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { // Init accumulator(s) with 0 __m256i acc1 = _mm256_setzero_si256(); @@ -184,29 +156,3 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { } return res; } - -static inline int32_t sqr7u_avx512(int8_t *a, int8_t *b, size_t dims) { - // Init accumulator(s) with 0 - __m512i acc1 = _mm512_setzero_si512(); - - const __m512i ones = _mm512_set1_epi16(1); - -#pragma GCC unroll 4 - for(int i = 0; i < dims; i += sizeof(__m512i)) { - // Load packed 8-bit integers - __m512i va = _mm512_loadu_si512(a + i); - __m512i vb = _mm512_loadu_si512(b + i); - - const __m512i dist = _mm512_sub_epi8(va, vb); - const __m512i abs_dist = _mm512_sign_epi8(dist, dist); - - // VNNI - //acc1 = _mm512_dpbusd_epi32(acc1, abs_dist, abs_dist); - const __m512i sqr = _mm512_maddubs_epi16(abs_dist, abs_dist); - acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, sqr), acc1); - } - - // reduce (accumulate all) - return _mm512_reduce_add_epi32(acc1); -} - diff --git a/libs/vec/native/src/vec/c/amd64/vec_2.c b/libs/vec/native/src/vec/c/amd64/vec_2.c new file mode 100644 index 0000000000000..1f3677585049f --- /dev/null +++ b/libs/vec/native/src/vec/c/amd64/vec_2.c @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +#include +#include +#include "vec.h" + +#ifdef _MSC_VER +#include +#elif __GNUC__ +#pragma GCC push_options +#pragma GCC target ("arch=skylake-avx512") +#include +#elif __clang__ +#pragma clang attribute push (__attribute__((target("arch=skylake-avx512"))), apply_to=function) +#include +#endif + +#include +#include + +#ifndef DOT7U_STRIDE_BYTES_LEN +#define DOT7U_STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 +#endif + +#ifndef SQR7U_STRIDE_BYTES_LEN +#define SQR7U_STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 +#endif + +static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { + const __m512i ones = _mm512_set1_epi16(1); + + // Init accumulator(s) with 0 + __m512i acc1 = _mm512_setzero_si512(); + +#pragma GCC unroll 4 + for(int i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) { + // Load 32 packed 8-bit integers + __m512i va = _mm512_loadu_si512(a + i); + __m512i vb = _mm512_loadu_si512(b + i); + + // Perform multiplication and create 16-bit values + // Vertically multiply each unsigned 8-bit integer from va with the corresponding + // signed 8-bit integer from vb, producing intermediate signed 16-bit integers. + // These values will be at max 32385, at min −32640, + // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + + // VNNI + //acc1 = _mm512_dpbusd_epi32(acc1, va1, vb1); + const __m512i vab = _mm512_maddubs_epi16(va, vb); + acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, vab), acc1); + } + + // reduce (accumulate all) + return _mm512_reduce_add_epi32(acc1); +} + +EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) { + int32_t res = 0; + int i = 0; + if (dims > DOT7U_STRIDE_BYTES_LEN) { + i += dims & ~(DOT7U_STRIDE_BYTES_LEN - 1); + res = dot7u_inner_avx512(a, b, i); + } + for (; i < dims; i++) { + res += a[i] * b[i]; + } + return res; +} + +static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { + const __m512i ones = _mm512_set1_epi16(1); + + // Init accumulator(s) with 0 + __m512i acc1 = _mm512_setzero_si512(); + +#pragma GCC unroll 4 + for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) { + // Load packed 8-bit integers + __m512i va = _mm512_loadu_si512(a + i); + __m512i vb = _mm512_loadu_si512(b + i); + + const __m512i dist = _mm512_sub_epi8(va, vb); + const __m512i abs_dist = _mm512_abs_epi8(dist); + + // VNNI + //acc1 = _mm512_dpbusd_epi32(acc1, abs_dist, abs_dist); + const __m512i sqr = _mm512_maddubs_epi16(abs_dist, abs_dist); + acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, sqr), acc1); + } + + // reduce (accumulate all) + return _mm512_reduce_add_epi32(acc1); +} + +EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { + int32_t res = 0; + int i = 0; + if (dims > SQR7U_STRIDE_BYTES_LEN) { + i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1); + res = sqr7u_inner_avx512(a, b, i); + } + for (; i < dims; i++) { + int32_t dist = a[i] - b[i]; + res += dist * dist; + } + return res; +} + +#ifdef __GNUC__ +#pragma GCC pop_options +#elif __clang__ +#pragma clang attribute pop +#endif + From 866199ce6bdbf71088a76bbf4ef078000802f6fe Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Mon, 27 May 2024 18:12:09 +0200 Subject: [PATCH 3/9] select FNNI function name based on vec_caps; templated implementation for manual unrolling --- .../nativeaccess/jdk/JdkVectorLibrary.java | 25 ++++-- libs/vec/native/build.gradle | 12 +++ .../src/vec/c/amd64/{vec_2.c => vec_2.cpp} | 81 ++++++++++++++----- 3 files changed, 92 insertions(+), 26 deletions(-) rename libs/vec/native/src/vec/c/amd64/{vec_2.c => vec_2.cpp} (51%) diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index 379b813e4657d..e720cf2be008c 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -23,8 +23,8 @@ public final class JdkVectorLibrary implements VectorLibrary { - static final String DOT7U_NAME; - static final String SQR7U_NAME; + static final String DOT7U_FUNCTION_NAME; + static final String SQR7U_FUNCTION_NAME; static final VectorSimilarityFunctions INSTANCE; @@ -35,12 +35,17 @@ public final class JdkVectorLibrary implements VectorLibrary { try { int caps = (int) vecCaps$mh.invokeExact(); if (caps != 0) { - DOT7U_NAME = "dot7u"; - SQR7U_NAME = "sqr7u"; + if (caps == 2) { + DOT7U_FUNCTION_NAME = "dot7u_2"; + SQR7U_FUNCTION_NAME = "sqr7u_2"; + } else { + DOT7U_FUNCTION_NAME = "dot7u"; + SQR7U_FUNCTION_NAME = "sqr7u"; + } INSTANCE = new JdkVectorSimilarityFunctions(); } else { - DOT7U_NAME = null; - SQR7U_NAME = null; + DOT7U_FUNCTION_NAME = null; + SQR7U_FUNCTION_NAME = null; INSTANCE = null; } } catch (Throwable t) { @@ -57,9 +62,13 @@ public VectorSimilarityFunctions getVectorSimilarityFunctions() { private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions { - static final MethodHandle dot7u$mh = downcallHandle(DOT7U_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - static final MethodHandle sqr7u$mh = downcallHandle(SQR7U_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + static final MethodHandle dot7u$mh; + static final MethodHandle sqr7u$mh; + static { + dot7u$mh = downcallHandle(DOT7U_FUNCTION_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + sqr7u$mh = downcallHandle(SQR7U_FUNCTION_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + } /** * Computes the dot product of given unsigned int7 byte vectors. * diff --git a/libs/vec/native/build.gradle b/libs/vec/native/build.gradle index ef9120680646a..1629bbc766382 100644 --- a/libs/vec/native/build.gradle +++ b/libs/vec/native/build.gradle @@ -6,6 +6,7 @@ * Side Public License, v 1. */ apply plugin: 'c' +apply plugin: 'cpp' var os = org.gradle.internal.os.OperatingSystem.current() @@ -51,6 +52,8 @@ model { target("amd64") { cCompiler.executable = "/usr/bin/gcc" cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2", "-Wno-incompatible-pointer-types"]) } + cppCompiler.executable = "/usr/bin/g++" + cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) } } } cl(VisualCpp) { @@ -86,6 +89,15 @@ model { srcDir "src/vec/headers/" } } + cpp { + source { + srcDir "src/vec/c/${platformName}/" + include "*.cpp" + } + exportedHeaders { + srcDir "src/vec/headers/" + } + } } } } diff --git a/libs/vec/native/src/vec/c/amd64/vec_2.c b/libs/vec/native/src/vec/c/amd64/vec_2.cpp similarity index 51% rename from libs/vec/native/src/vec/c/amd64/vec_2.c rename to libs/vec/native/src/vec/c/amd64/vec_2.cpp index 1f3677585049f..5eb9e4f9b2989 100644 --- a/libs/vec/native/src/vec/c/amd64/vec_2.c +++ b/libs/vec/native/src/vec/c/amd64/vec_2.cpp @@ -32,34 +32,78 @@ #define SQR7U_STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 #endif -static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { +// Returns acc + ( p1 * p2 ), for 64-wide int lanes. +template +inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) { + constexpr int lanes = offsetRegs * sizeof(__m512i); + const __m512i a = _mm512_loadu_si512((const __m512i*)(p1 + lanes)); + const __m512i b = _mm512_loadu_si512((const __m512i*)(p2 + lanes)); + // Perform multiplication and create 16-bit values + // Vertically multiply each unsigned 8-bit integer from a with the corresponding + // signed 8-bit integer from b, producing intermediate signed 16-bit integers. + // These values will be at max 32385, at min −32640 + const __m512i dot = _mm512_maddubs_epi16(a, b); const __m512i ones = _mm512_set1_epi16(1); + // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + return _mm512_add_epi32(_mm512_madd_epi16(ones, dot), acc); + //const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(dot), _mm256_extractf128_si256(dot, 1)); + //return _mm256_add_epi32(acc, _mm256_cvtepi16_epi32(sum128)); +} + +static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { + constexpr int stride2 = 8 * DOT7U_STRIDE_BYTES_LEN; + constexpr int stride = 4 * DOT7U_STRIDE_BYTES_LEN; + const int8_t* p1 = a; + const int8_t* p2 = b; + + const ptrdiff_t rem = (( dims - 1 ) % sizeof(__m512i)) + 1; + const int8_t* const p1End = p1 + dims - rem; // Init accumulator(s) with 0 + __m512i acc0 = _mm512_setzero_si512(); __m512i acc1 = _mm512_setzero_si512(); + __m512i acc2 = _mm512_setzero_si512(); + __m512i acc3 = _mm512_setzero_si512(); + __m512i acc4 = _mm512_setzero_si512(); + __m512i acc5 = _mm512_setzero_si512(); + __m512i acc6 = _mm512_setzero_si512(); + __m512i acc7 = _mm512_setzero_si512(); + + while (p1 < p1End) { + acc0 = fma8<0>(acc0, p1, p2); + acc1 = fma8<1>(acc1, p1, p2); + acc2 = fma8<2>(acc2, p1, p2); + acc3 = fma8<3>(acc3, p1, p2); + acc4 = fma8<4>(acc4, p1, p2); + acc5 = fma8<5>(acc5, p1, p2); + acc6 = fma8<6>(acc6, p1, p2); + acc7 = fma8<7>(acc7, p1, p2); + p1 += stride2; + p2 += stride2; + } -#pragma GCC unroll 4 - for(int i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) { - // Load 32 packed 8-bit integers - __m512i va = _mm512_loadu_si512(a + i); - __m512i vb = _mm512_loadu_si512(b + i); - - // Perform multiplication and create 16-bit values - // Vertically multiply each unsigned 8-bit integer from va with the corresponding - // signed 8-bit integer from vb, producing intermediate signed 16-bit integers. - // These values will be at max 32385, at min −32640, - // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + while (p1 < p1End) { + acc0 = fma8<0>(acc0, p1, p2); + acc1 = fma8<1>(acc1, p1, p2); + acc2 = fma8<2>(acc2, p1, p2); + acc3 = fma8<3>(acc3, p1, p2); + p1 += stride; + p2 += stride; + } - // VNNI - //acc1 = _mm512_dpbusd_epi32(acc1, va1, vb1); - const __m512i vab = _mm512_maddubs_epi16(va, vb); - acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, vab), acc1); + while (p1 < p1End) { + acc0 = fma8<0>(acc0, p1, p2); + p1 += DOT7U_STRIDE_BYTES_LEN; + p2 += DOT7U_STRIDE_BYTES_LEN; } // reduce (accumulate all) - return _mm512_reduce_add_epi32(acc1); + acc0 = _mm512_add_epi32(_mm512_add_epi32(acc0, acc1), _mm512_add_epi32(acc2, acc3)); + acc4 = _mm512_add_epi32(_mm512_add_epi32(acc4, acc5), _mm512_add_epi32(acc6, acc7)); + return _mm512_reduce_add_epi32(_mm512_add_epi32(acc0, acc4)); } +extern "C" EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; @@ -79,7 +123,7 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { // Init accumulator(s) with 0 __m512i acc1 = _mm512_setzero_si512(); -#pragma GCC unroll 4 +#pragma GCC unroll 8 for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) { // Load packed 8-bit integers __m512i va = _mm512_loadu_si512(a + i); @@ -98,6 +142,7 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { return _mm512_reduce_add_epi32(acc1); } +extern "C" EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; From 83b820ab372e3ed13ac4db16ccd932196fa5a523 Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Mon, 27 May 2024 19:13:04 +0200 Subject: [PATCH 4/9] Manual unroll sqr7u + static bind mh in outer class --- .../nativeaccess/jdk/JdkVectorLibrary.java | 28 +++----- libs/vec/native/src/vec/c/amd64/vec_2.cpp | 70 +++++++++++++++---- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index e720cf2be008c..004d13b128a9b 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -23,8 +23,8 @@ public final class JdkVectorLibrary implements VectorLibrary { - static final String DOT7U_FUNCTION_NAME; - static final String SQR7U_FUNCTION_NAME; + static final MethodHandle dot7u$mh; + static final MethodHandle sqr7u$mh; static final VectorSimilarityFunctions INSTANCE; @@ -36,16 +36,16 @@ public final class JdkVectorLibrary implements VectorLibrary { int caps = (int) vecCaps$mh.invokeExact(); if (caps != 0) { if (caps == 2) { - DOT7U_FUNCTION_NAME = "dot7u_2"; - SQR7U_FUNCTION_NAME = "sqr7u_2"; + dot7u$mh = downcallHandle("dot7u_2", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + sqr7u$mh = downcallHandle("sqr7u_2", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); } else { - DOT7U_FUNCTION_NAME = "dot7u"; - SQR7U_FUNCTION_NAME = "sqr7u"; + dot7u$mh = downcallHandle("dot7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + sqr7u$mh = downcallHandle("sqr7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); } INSTANCE = new JdkVectorSimilarityFunctions(); } else { - DOT7U_FUNCTION_NAME = null; - SQR7U_FUNCTION_NAME = null; + dot7u$mh = null; + sqr7u$mh = null; INSTANCE = null; } } catch (Throwable t) { @@ -61,14 +61,6 @@ public VectorSimilarityFunctions getVectorSimilarityFunctions() { } private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions { - - static final MethodHandle dot7u$mh; - static final MethodHandle sqr7u$mh; - - static { - dot7u$mh = downcallHandle(DOT7U_FUNCTION_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - sqr7u$mh = downcallHandle(SQR7U_FUNCTION_NAME, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - } /** * Computes the dot product of given unsigned int7 byte vectors. * @@ -111,7 +103,7 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { private static int dot7u(MemorySegment a, MemorySegment b, int length) { try { - return (int) dot7u$mh.invokeExact(a, b, length); + return (int) JdkVectorLibrary.dot7u$mh.invokeExact(a, b, length); } catch (Throwable t) { throw new AssertionError(t); } @@ -119,7 +111,7 @@ private static int dot7u(MemorySegment a, MemorySegment b, int length) { private static int sqr7u(MemorySegment a, MemorySegment b, int length) { try { - return (int) sqr7u$mh.invokeExact(a, b, length); + return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length); } catch (Throwable t) { throw new AssertionError(t); } diff --git a/libs/vec/native/src/vec/c/amd64/vec_2.cpp b/libs/vec/native/src/vec/c/amd64/vec_2.cpp index 5eb9e4f9b2989..bd1447b6ce968 100644 --- a/libs/vec/native/src/vec/c/amd64/vec_2.cpp +++ b/libs/vec/native/src/vec/c/amd64/vec_2.cpp @@ -117,29 +117,73 @@ EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) { return res; } -static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { +template +inline __m512i sqr8(__m512i acc, const int8_t* p1, const int8_t* p2) { + constexpr int lanes = offsetRegs * sizeof(__m512i); + const __m512i a = _mm512_loadu_si512((const __m512i*)(p1 + lanes)); + const __m512i b = _mm512_loadu_si512((const __m512i*)(p2 + lanes)); + + const __m512i dist = _mm512_sub_epi8(a, b); + const __m512i abs_dist = _mm512_abs_epi8(dist); + const __m512i sqr_add = _mm512_maddubs_epi16(abs_dist, abs_dist); const __m512i ones = _mm512_set1_epi16(1); + // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + return _mm512_add_epi32(_mm512_madd_epi16(ones, sqr_add), acc); + //const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(dot), _mm256_extractf128_si256(dot, 1)); + //return _mm256_add_epi32(acc, _mm256_cvtepi16_epi32(sum128)); +} + +static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { + constexpr int stride2 = 8 * SQR7U_STRIDE_BYTES_LEN; + constexpr int stride = 4 * SQR7U_STRIDE_BYTES_LEN; + const int8_t* p1 = a; + const int8_t* p2 = b; + + const ptrdiff_t rem = (( dims - 1 ) % sizeof(__m512i)) + 1; + const int8_t* const p1End = p1 + dims - rem; // Init accumulator(s) with 0 + __m512i acc0 = _mm512_setzero_si512(); __m512i acc1 = _mm512_setzero_si512(); + __m512i acc2 = _mm512_setzero_si512(); + __m512i acc3 = _mm512_setzero_si512(); + __m512i acc4 = _mm512_setzero_si512(); + __m512i acc5 = _mm512_setzero_si512(); + __m512i acc6 = _mm512_setzero_si512(); + __m512i acc7 = _mm512_setzero_si512(); -#pragma GCC unroll 8 - for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) { - // Load packed 8-bit integers - __m512i va = _mm512_loadu_si512(a + i); - __m512i vb = _mm512_loadu_si512(b + i); + while (p1 < p1End) { + acc0 = sqr8<0>(acc0, p1, p2); + acc1 = sqr8<1>(acc1, p1, p2); + acc2 = sqr8<2>(acc2, p1, p2); + acc3 = sqr8<3>(acc3, p1, p2); + acc4 = sqr8<4>(acc4, p1, p2); + acc5 = sqr8<5>(acc5, p1, p2); + acc6 = sqr8<6>(acc6, p1, p2); + acc7 = sqr8<7>(acc7, p1, p2); + p1 += stride2; + p2 += stride2; + } - const __m512i dist = _mm512_sub_epi8(va, vb); - const __m512i abs_dist = _mm512_abs_epi8(dist); + while (p1 < p1End) { + acc0 = sqr8<0>(acc0, p1, p2); + acc1 = sqr8<1>(acc1, p1, p2); + acc2 = sqr8<2>(acc2, p1, p2); + acc3 = sqr8<3>(acc3, p1, p2); + p1 += stride; + p2 += stride; + } - // VNNI - //acc1 = _mm512_dpbusd_epi32(acc1, abs_dist, abs_dist); - const __m512i sqr = _mm512_maddubs_epi16(abs_dist, abs_dist); - acc1 = _mm512_add_epi32(_mm512_madd_epi16(ones, sqr), acc1); + while (p1 < p1End) { + acc0 = sqr8<0>(acc0, p1, p2); + p1 += SQR7U_STRIDE_BYTES_LEN; + p2 += SQR7U_STRIDE_BYTES_LEN; } // reduce (accumulate all) - return _mm512_reduce_add_epi32(acc1); + acc0 = _mm512_add_epi32(_mm512_add_epi32(acc0, acc1), _mm512_add_epi32(acc2, acc3)); + acc4 = _mm512_add_epi32(_mm512_add_epi32(acc4, acc5), _mm512_add_epi32(acc6, acc7)); + return _mm512_reduce_add_epi32(_mm512_add_epi32(acc0, acc4)); } extern "C" From ee7094b232ddee110642f869e801f7dea436b4cd Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Fri, 31 May 2024 11:06:26 +0200 Subject: [PATCH 5/9] Switched compiler to clang for x64, as gcc has a bug --- .../native/{Dockerfile => Dockerfile.aarch64} | 2 +- libs/vec/native/Dockerfile.amd64 | 17 +++++ libs/vec/native/build.gradle | 5 +- libs/vec/native/publish_vec_binaries.sh | 4 +- libs/vec/native/src/vec/c/amd64/vec.c | 24 +++--- libs/vec/native/src/vec/c/amd64/vec_2.cpp | 76 +++++++++---------- 6 files changed, 67 insertions(+), 61 deletions(-) rename libs/vec/native/{Dockerfile => Dockerfile.aarch64} (75%) create mode 100644 libs/vec/native/Dockerfile.amd64 diff --git a/libs/vec/native/Dockerfile b/libs/vec/native/Dockerfile.aarch64 similarity index 75% rename from libs/vec/native/Dockerfile rename to libs/vec/native/Dockerfile.aarch64 index 607e206615ad1..66eb7e92ef479 100644 --- a/libs/vec/native/Dockerfile +++ b/libs/vec/native/Dockerfile.aarch64 @@ -4,7 +4,7 @@ RUN apt update RUN apt install -y gcc g++ openjdk-17-jdk COPY . /workspace WORKDIR /workspace -RUN ./gradlew --console=plain clean buildSharedLibrary +RUN ./gradlew --quiet --console=plain clean buildSharedLibrary RUN strip --strip-unneeded build/output/libvec.so CMD cat build/output/libvec.so diff --git a/libs/vec/native/Dockerfile.amd64 b/libs/vec/native/Dockerfile.amd64 new file mode 100644 index 0000000000000..f6e79430ed7a9 --- /dev/null +++ b/libs/vec/native/Dockerfile.amd64 @@ -0,0 +1,17 @@ +FROM debian:latest + +RUN apt update +RUN apt install -y wget +RUN echo "deb http://apt.llvm.org/bookworm/ llvm-toolchain-bookworm-18 main" > /etc/apt/sources.list.d/clang.list +RUN wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc +RUN apt update +RUN apt install -y clang-18 openjdk-17-jdk +RUN ln -s /usr/bin/clang-18 /usr/bin/clang +RUN ln -s /usr/bin/clang++-18 /usr/bin/clang++ +RUN apt install -y +COPY . /workspace +WORKDIR /workspace +RUN ./gradlew --quiet --console=plain clean buildSharedLibrary +RUN strip --strip-unneeded build/output/libvec.so + +CMD cat build/output/libvec.so diff --git a/libs/vec/native/build.gradle b/libs/vec/native/build.gradle index 1629bbc766382..073477c3aebf2 100644 --- a/libs/vec/native/build.gradle +++ b/libs/vec/native/build.gradle @@ -13,8 +13,8 @@ var os = org.gradle.internal.os.OperatingSystem.current() // To update this library run publish_vec_binaries.sh ( or ./gradlew vecSharedLibrary ) // Or // For local development, build the docker image with: -// docker build --platform linux/arm64 --progress=plain . (for aarch64) -// docker build --platform linux/amd64 --progress=plain . (for x64) +// docker build --platform linux/arm64 --progress=plain --file=Dockerfile.aarch64 . (for aarch64) +// docker build --platform linux/amd64 --progress=plain --file=Dockerfile.amd64 . (for x64) // Grab the image id from the console output, then, e.g. // docker run 9c9f36564c148b275aeecc42749e7b4580ded79dcf51ff6ccc008c8861e7a979 > build/libs/vec/shared/$arch/libvec.so // @@ -71,6 +71,7 @@ model { target("amd64") { cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2"]) } + cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) } } } } diff --git a/libs/vec/native/publish_vec_binaries.sh b/libs/vec/native/publish_vec_binaries.sh index d11645ff71c4a..db836618a7bf8 100755 --- a/libs/vec/native/publish_vec_binaries.sh +++ b/libs/vec/native/publish_vec_binaries.sh @@ -33,11 +33,11 @@ echo 'Building Darwin binary...' echo 'Building Linux binary...' mkdir -p build/libs/vec/shared/aarch64/ -DOCKER_IMAGE=$(docker build --platform linux/arm64 --quiet .) +DOCKER_IMAGE=$(docker build --platform linux/arm64 --quiet --file=Dockerfile.aarch64 .) docker run $DOCKER_IMAGE > build/libs/vec/shared/aarch64/libvec.so echo 'Building Linux x64 binary...' -DOCKER_IMAGE=$(docker build --platform linux/amd64 --quiet .) +DOCKER_IMAGE=$(docker build --platform linux/amd64 --quiet --file=Dockerfile.amd64 .) mkdir -p build/libs/vec/shared/amd64 docker run --platform linux/amd64 $DOCKER_IMAGE > build/libs/vec/shared/amd64/libvec.so diff --git a/libs/vec/native/src/vec/c/amd64/vec.c b/libs/vec/native/src/vec/c/amd64/vec.c index 654009ef02f99..0fa17109fac6b 100644 --- a/libs/vec/native/src/vec/c/amd64/vec.c +++ b/libs/vec/native/src/vec/c/amd64/vec.c @@ -13,20 +13,16 @@ #include #include -#ifndef DOT7U_STRIDE_BYTES_LEN -#define DOT7U_STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2 -#endif - -#ifndef SQR7U_STRIDE_BYTES_LEN -#define SQR7U_STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2 +#ifndef STRIDE_BYTES_LEN +#define STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2 #endif #ifdef _MSC_VER #include -#elif __GNUC__ -#include #elif __clang__ #include +#elif __GNUC__ +#include #endif // Multi-platform CPUID "intrinsic"; it takes as input a "functionNumber" (or "leaf", the eax registry). "Subleaf" @@ -91,7 +87,7 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) { __m256i acc1 = _mm256_setzero_si256(); #pragma GCC unroll 4 - for(int i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) { + for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) { // Load packed 8-bit integers __m256i va1 = _mm256_loadu_si256(a + i); __m256i vb1 = _mm256_loadu_si256(b + i); @@ -111,8 +107,8 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) { EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; - if (dims > DOT7U_STRIDE_BYTES_LEN) { - i += dims & ~(DOT7U_STRIDE_BYTES_LEN - 1); + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); res = dot7u_inner(a, b, i); } for (; i < dims; i++) { @@ -128,7 +124,7 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { const __m256i ones = _mm256_set1_epi16(1); #pragma GCC unroll 4 - for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) { + for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) { // Load packed 8-bit integers __m256i va1 = _mm256_loadu_si256(a + i); __m256i vb1 = _mm256_loadu_si256(b + i); @@ -146,8 +142,8 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; - if (dims > SQR7U_STRIDE_BYTES_LEN) { - i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1); + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); res = sqr7u_inner(a, b, i); } for (; i < dims; i++) { diff --git a/libs/vec/native/src/vec/c/amd64/vec_2.cpp b/libs/vec/native/src/vec/c/amd64/vec_2.cpp index bd1447b6ce968..410367d546f59 100644 --- a/libs/vec/native/src/vec/c/amd64/vec_2.cpp +++ b/libs/vec/native/src/vec/c/amd64/vec_2.cpp @@ -12,30 +12,26 @@ #ifdef _MSC_VER #include +#elif __clang__ +#pragma clang attribute push(__attribute__((target("arch=skylake-avx512"))), apply_to=function) +#include #elif __GNUC__ #pragma GCC push_options #pragma GCC target ("arch=skylake-avx512") #include -#elif __clang__ -#pragma clang attribute push (__attribute__((target("arch=skylake-avx512"))), apply_to=function) -#include #endif #include #include -#ifndef DOT7U_STRIDE_BYTES_LEN -#define DOT7U_STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 -#endif - -#ifndef SQR7U_STRIDE_BYTES_LEN -#define SQR7U_STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 +#ifndef STRIDE_BYTES_LEN +#define STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 #endif // Returns acc + ( p1 * p2 ), for 64-wide int lanes. template inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) { - constexpr int lanes = offsetRegs * sizeof(__m512i); + constexpr int lanes = offsetRegs * STRIDE_BYTES_LEN; const __m512i a = _mm512_loadu_si512((const __m512i*)(p1 + lanes)); const __m512i b = _mm512_loadu_si512((const __m512i*)(p2 + lanes)); // Perform multiplication and create 16-bit values @@ -44,19 +40,18 @@ inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) { // These values will be at max 32385, at min −32640 const __m512i dot = _mm512_maddubs_epi16(a, b); const __m512i ones = _mm512_set1_epi16(1); - // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + // Horizontally add adjacent pairs of intermediate signed 16-bit ints, and pack the results in 32-bit ints. + // Using madd with 1, as this is faster than extract 2 halves, add 16-bit ints, and convert to 32-bit ints. return _mm512_add_epi32(_mm512_madd_epi16(ones, dot), acc); - //const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(dot), _mm256_extractf128_si256(dot, 1)); - //return _mm256_add_epi32(acc, _mm256_cvtepi16_epi32(sum128)); } static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { - constexpr int stride2 = 8 * DOT7U_STRIDE_BYTES_LEN; - constexpr int stride = 4 * DOT7U_STRIDE_BYTES_LEN; + constexpr int stride8 = 8 * STRIDE_BYTES_LEN; + constexpr int stride4 = 4 * STRIDE_BYTES_LEN; const int8_t* p1 = a; const int8_t* p2 = b; - const ptrdiff_t rem = (( dims - 1 ) % sizeof(__m512i)) + 1; + const ptrdiff_t rem = (( dims - 1 ) % STRIDE_BYTES_LEN) + 1; const int8_t* const p1End = p1 + dims - rem; // Init accumulator(s) with 0 @@ -78,8 +73,8 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { acc5 = fma8<5>(acc5, p1, p2); acc6 = fma8<6>(acc6, p1, p2); acc7 = fma8<7>(acc7, p1, p2); - p1 += stride2; - p2 += stride2; + p1 += stride8; + p2 += stride8; } while (p1 < p1End) { @@ -87,14 +82,14 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { acc1 = fma8<1>(acc1, p1, p2); acc2 = fma8<2>(acc2, p1, p2); acc3 = fma8<3>(acc3, p1, p2); - p1 += stride; - p2 += stride; + p1 += stride4; + p2 += stride4; } while (p1 < p1End) { acc0 = fma8<0>(acc0, p1, p2); - p1 += DOT7U_STRIDE_BYTES_LEN; - p2 += DOT7U_STRIDE_BYTES_LEN; + p1 += STRIDE_BYTES_LEN; + p2 += STRIDE_BYTES_LEN; } // reduce (accumulate all) @@ -107,8 +102,8 @@ extern "C" EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; - if (dims > DOT7U_STRIDE_BYTES_LEN) { - i += dims & ~(DOT7U_STRIDE_BYTES_LEN - 1); + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); res = dot7u_inner_avx512(a, b, i); } for (; i < dims; i++) { @@ -119,7 +114,7 @@ EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) { template inline __m512i sqr8(__m512i acc, const int8_t* p1, const int8_t* p2) { - constexpr int lanes = offsetRegs * sizeof(__m512i); + constexpr int lanes = offsetRegs * STRIDE_BYTES_LEN; const __m512i a = _mm512_loadu_si512((const __m512i*)(p1 + lanes)); const __m512i b = _mm512_loadu_si512((const __m512i*)(p2 + lanes)); @@ -129,17 +124,15 @@ inline __m512i sqr8(__m512i acc, const int8_t* p1, const int8_t* p2) { const __m512i ones = _mm512_set1_epi16(1); // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. return _mm512_add_epi32(_mm512_madd_epi16(ones, sqr_add), acc); - //const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(dot), _mm256_extractf128_si256(dot, 1)); - //return _mm256_add_epi32(acc, _mm256_cvtepi16_epi32(sum128)); } static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { - constexpr int stride2 = 8 * SQR7U_STRIDE_BYTES_LEN; - constexpr int stride = 4 * SQR7U_STRIDE_BYTES_LEN; + constexpr int stride8 = 8 * STRIDE_BYTES_LEN; + constexpr int stride4 = 4 * STRIDE_BYTES_LEN; const int8_t* p1 = a; const int8_t* p2 = b; - const ptrdiff_t rem = (( dims - 1 ) % sizeof(__m512i)) + 1; + const ptrdiff_t rem = (( dims - 1 ) % STRIDE_BYTES_LEN) + 1; const int8_t* const p1End = p1 + dims - rem; // Init accumulator(s) with 0 @@ -161,8 +154,8 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { acc5 = sqr8<5>(acc5, p1, p2); acc6 = sqr8<6>(acc6, p1, p2); acc7 = sqr8<7>(acc7, p1, p2); - p1 += stride2; - p2 += stride2; + p1 += stride8; + p2 += stride8; } while (p1 < p1End) { @@ -170,14 +163,14 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { acc1 = sqr8<1>(acc1, p1, p2); acc2 = sqr8<2>(acc2, p1, p2); acc3 = sqr8<3>(acc3, p1, p2); - p1 += stride; - p2 += stride; + p1 += stride4; + p2 += stride4; } while (p1 < p1End) { acc0 = sqr8<0>(acc0, p1, p2); - p1 += SQR7U_STRIDE_BYTES_LEN; - p2 += SQR7U_STRIDE_BYTES_LEN; + p1 += STRIDE_BYTES_LEN; + p2 += STRIDE_BYTES_LEN; } // reduce (accumulate all) @@ -190,8 +183,8 @@ extern "C" EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; - if (dims > SQR7U_STRIDE_BYTES_LEN) { - i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1); + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); res = sqr7u_inner_avx512(a, b, i); } for (; i < dims; i++) { @@ -201,9 +194,8 @@ EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { return res; } -#ifdef __GNUC__ -#pragma GCC pop_options -#elif __clang__ +#ifdef __clang__ #pragma clang attribute pop +#elif __GNUC__ +#pragma GCC pop_options #endif - From f59b38329355beea70b277b522b78413584d7120 Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Thu, 27 Jun 2024 12:59:50 +0200 Subject: [PATCH 6/9] Update native simdvec library version --- libs/native/libraries/build.gradle | 2 +- libs/simdvec/native/publish_vec_binaries.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index b7e6a1c704e6e..8f1a12055bd7e 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -18,7 +18,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.9" +var vecVersion = "1.0.10" repositories { exclusiveContent { diff --git a/libs/simdvec/native/publish_vec_binaries.sh b/libs/simdvec/native/publish_vec_binaries.sh index db836618a7bf8..ddb3d2c71e448 100755 --- a/libs/simdvec/native/publish_vec_binaries.sh +++ b/libs/simdvec/native/publish_vec_binaries.sh @@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.9" +VERSION="1.0.10" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) From 3e90ff8ad69f17b8fc7ebdc33c82741bbb8c58e1 Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Thu, 27 Jun 2024 16:53:02 +0200 Subject: [PATCH 7/9] Small fix to Dockerfile.amd64 --- libs/simdvec/native/Dockerfile.amd64 | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/simdvec/native/Dockerfile.amd64 b/libs/simdvec/native/Dockerfile.amd64 index f6e79430ed7a9..77acf8e42cdd2 100644 --- a/libs/simdvec/native/Dockerfile.amd64 +++ b/libs/simdvec/native/Dockerfile.amd64 @@ -8,7 +8,6 @@ RUN apt update RUN apt install -y clang-18 openjdk-17-jdk RUN ln -s /usr/bin/clang-18 /usr/bin/clang RUN ln -s /usr/bin/clang++-18 /usr/bin/clang++ -RUN apt install -y COPY . /workspace WORKDIR /workspace RUN ./gradlew --quiet --console=plain clean buildSharedLibrary From 584b54d0f280ad0bd5c1815b28b96ac0426880fc Mon Sep 17 00:00:00 2001 From: Lorenzo Dematte Date: Thu, 27 Jun 2024 18:54:34 +0200 Subject: [PATCH 8/9] Fix end boundaries for non-power of 2 dims --- libs/simdvec/native/src/vec/c/amd64/vec_2.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp index 410367d546f59..1606b31907405 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp +++ b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp @@ -51,9 +51,6 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { const int8_t* p1 = a; const int8_t* p2 = b; - const ptrdiff_t rem = (( dims - 1 ) % STRIDE_BYTES_LEN) + 1; - const int8_t* const p1End = p1 + dims - rem; - // Init accumulator(s) with 0 __m512i acc0 = _mm512_setzero_si512(); __m512i acc1 = _mm512_setzero_si512(); @@ -64,6 +61,7 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { __m512i acc6 = _mm512_setzero_si512(); __m512i acc7 = _mm512_setzero_si512(); + const int8_t* p1End = a + (dims & ~(stride8 - 1)); while (p1 < p1End) { acc0 = fma8<0>(acc0, p1, p2); acc1 = fma8<1>(acc1, p1, p2); @@ -77,6 +75,7 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { p2 += stride8; } + p1End = a + (dims & ~(stride4 - 1)); while (p1 < p1End) { acc0 = fma8<0>(acc0, p1, p2); acc1 = fma8<1>(acc1, p1, p2); @@ -86,6 +85,7 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { p2 += stride4; } + p1End = a + (dims & ~(STRIDE_BYTES_LEN - 1)); while (p1 < p1End) { acc0 = fma8<0>(acc0, p1, p2); p1 += STRIDE_BYTES_LEN; @@ -132,9 +132,6 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { const int8_t* p1 = a; const int8_t* p2 = b; - const ptrdiff_t rem = (( dims - 1 ) % STRIDE_BYTES_LEN) + 1; - const int8_t* const p1End = p1 + dims - rem; - // Init accumulator(s) with 0 __m512i acc0 = _mm512_setzero_si512(); __m512i acc1 = _mm512_setzero_si512(); @@ -145,6 +142,7 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { __m512i acc6 = _mm512_setzero_si512(); __m512i acc7 = _mm512_setzero_si512(); + const int8_t* p1End = a + (dims & ~(stride8 - 1)); while (p1 < p1End) { acc0 = sqr8<0>(acc0, p1, p2); acc1 = sqr8<1>(acc1, p1, p2); @@ -158,6 +156,7 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { p2 += stride8; } + p1End = a + (dims & ~(stride4 - 1)); while (p1 < p1End) { acc0 = sqr8<0>(acc0, p1, p2); acc1 = sqr8<1>(acc1, p1, p2); @@ -167,6 +166,7 @@ static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { p2 += stride4; } + p1End = a + (dims & ~(STRIDE_BYTES_LEN - 1)); while (p1 < p1End) { acc0 = sqr8<0>(acc0, p1, p2); p1 += STRIDE_BYTES_LEN; From aced309ae303788f1f12efdc1e8906119b27cf1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lorenzo=20Dematt=C3=A9?= Date: Thu, 27 Jun 2024 19:16:06 +0200 Subject: [PATCH 9/9] Update docs/changelog/109084.yaml --- docs/changelog/109084.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/109084.yaml diff --git a/docs/changelog/109084.yaml b/docs/changelog/109084.yaml new file mode 100644 index 0000000000000..67ff5610c5a66 --- /dev/null +++ b/docs/changelog/109084.yaml @@ -0,0 +1,5 @@ +pr: 109084 +summary: Add AVX-512 optimised vector distance functions for int7 on x64 +area: Search +type: enhancement +issues: []