diff --git a/docs/changelog/109084.yaml b/docs/changelog/109084.yaml new file mode 100644 index 0000000000000..67ff5610c5a66 --- /dev/null +++ b/docs/changelog/109084.yaml @@ -0,0 +1,5 @@ +pr: 109084 +summary: Add AVX-512 optimised vector distance functions for int7 on x64 +area: Search +type: enhancement +issues: [] diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index b7e6a1c704e6e..8f1a12055bd7e 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -18,7 +18,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.9" +var vecVersion = "1.0.10" repositories { exclusiveContent { diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index db2e7b85c30d0..c92ad654c9b9a 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -23,6 +23,9 @@ public final class JdkVectorLibrary implements VectorLibrary { + static final MethodHandle dot7u$mh; + static final MethodHandle sqr7u$mh; + static final VectorSimilarityFunctions INSTANCE; static { @@ -32,8 +35,33 @@ public final class JdkVectorLibrary implements VectorLibrary { try { int caps = (int) vecCaps$mh.invokeExact(); if (caps != 0) { + if (caps == 2) { + dot7u$mh = downcallHandle( + "dot7u_2", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + sqr7u$mh = downcallHandle( + "sqr7u_2", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + } else { + dot7u$mh = downcallHandle( + "dot7u", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + sqr7u$mh = downcallHandle( + "sqr7u", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, 
JAVA_INT), + LinkerHelperUtil.critical() + ); + } INSTANCE = new JdkVectorSimilarityFunctions(); } else { + dot7u$mh = null; + sqr7u$mh = null; INSTANCE = null; } } catch (Throwable t) { @@ -49,18 +77,6 @@ public VectorSimilarityFunctions getVectorSimilarityFunctions() { } private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions { - - static final MethodHandle dot7u$mh = downcallHandle( - "dot7u", - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), - LinkerHelperUtil.critical() - ); - static final MethodHandle sqr7u$mh = downcallHandle( - "sqr7u", - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), - LinkerHelperUtil.critical() - ); - /** * Computes the dot product of given unsigned int7 byte vectors. * @@ -103,7 +119,7 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { private static int dot7u(MemorySegment a, MemorySegment b, int length) { try { - return (int) dot7u$mh.invokeExact(a, b, length); + return (int) JdkVectorLibrary.dot7u$mh.invokeExact(a, b, length); } catch (Throwable t) { throw new AssertionError(t); } @@ -111,7 +127,7 @@ private static int dot7u(MemorySegment a, MemorySegment b, int length) { private static int sqr7u(MemorySegment a, MemorySegment b, int length) { try { - return (int) sqr7u$mh.invokeExact(a, b, length); + return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length); } catch (Throwable t) { throw new AssertionError(t); } diff --git a/libs/simdvec/native/Dockerfile b/libs/simdvec/native/Dockerfile.aarch64 similarity index 100% rename from libs/simdvec/native/Dockerfile rename to libs/simdvec/native/Dockerfile.aarch64 diff --git a/libs/simdvec/native/Dockerfile.amd64 b/libs/simdvec/native/Dockerfile.amd64 new file mode 100644 index 0000000000000..77acf8e42cdd2 --- /dev/null +++ b/libs/simdvec/native/Dockerfile.amd64 @@ -0,0 +1,16 @@ +FROM debian:latest + +RUN apt update +RUN apt install -y wget +RUN echo "deb http://apt.llvm.org/bookworm/ 
llvm-toolchain-bookworm-18 main" > /etc/apt/sources.list.d/clang.list +RUN wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc +RUN apt update +RUN apt install -y clang-18 openjdk-17-jdk +RUN ln -s /usr/bin/clang-18 /usr/bin/clang +RUN ln -s /usr/bin/clang++-18 /usr/bin/clang++ +COPY . /workspace +WORKDIR /workspace +RUN ./gradlew --quiet --console=plain clean buildSharedLibrary +RUN strip --strip-unneeded build/output/libvec.so + +CMD cat build/output/libvec.so diff --git a/libs/simdvec/native/build.gradle b/libs/simdvec/native/build.gradle index ef9120680646a..073477c3aebf2 100644 --- a/libs/simdvec/native/build.gradle +++ b/libs/simdvec/native/build.gradle @@ -6,14 +6,15 @@ * Side Public License, v 1. */ apply plugin: 'c' +apply plugin: 'cpp' var os = org.gradle.internal.os.OperatingSystem.current() // To update this library run publish_vec_binaries.sh ( or ./gradlew vecSharedLibrary ) // Or // For local development, build the docker image with: -// docker build --platform linux/arm64 --progress=plain . (for aarch64) -// docker build --platform linux/amd64 --progress=plain . (for x64) +// docker build --platform linux/arm64 --progress=plain --file=Dockerfile.aarch64 . (for aarch64) +// docker build --platform linux/amd64 --progress=plain --file=Dockerfile.amd64 . (for x64) // Grab the image id from the console output, then, e.g. 
// docker run 9c9f36564c148b275aeecc42749e7b4580ded79dcf51ff6ccc008c8861e7a979 > build/libs/vec/shared/$arch/libvec.so // @@ -51,6 +52,8 @@ model { target("amd64") { cCompiler.executable = "/usr/bin/gcc" cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2", "-Wno-incompatible-pointer-types"]) } + cppCompiler.executable = "/usr/bin/g++" + cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) } } } cl(VisualCpp) { @@ -68,6 +71,7 @@ model { target("amd64") { cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2"]) } + cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) } } } } @@ -86,6 +90,15 @@ model { srcDir "src/vec/headers/" } } + cpp { + source { + srcDir "src/vec/c/${platformName}/" + include "*.cpp" + } + exportedHeaders { + srcDir "src/vec/headers/" + } + } } } } diff --git a/libs/simdvec/native/publish_vec_binaries.sh b/libs/simdvec/native/publish_vec_binaries.sh index d11645ff71c4a..ddb3d2c71e448 100755 --- a/libs/simdvec/native/publish_vec_binaries.sh +++ b/libs/simdvec/native/publish_vec_binaries.sh @@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.9" +VERSION="1.0.10" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) @@ -33,11 +33,11 @@ echo 'Building Darwin binary...' echo 'Building Linux binary...' mkdir -p build/libs/vec/shared/aarch64/ -DOCKER_IMAGE=$(docker build --platform linux/arm64 --quiet .) +DOCKER_IMAGE=$(docker build --platform linux/arm64 --quiet --file=Dockerfile.aarch64 .) docker run $DOCKER_IMAGE > build/libs/vec/shared/aarch64/libvec.so echo 'Building Linux x64 binary...' -DOCKER_IMAGE=$(docker build --platform linux/amd64 --quiet .) +DOCKER_IMAGE=$(docker build --platform linux/amd64 --quiet --file=Dockerfile.amd64 .) 
mkdir -p build/libs/vec/shared/amd64 docker run --platform linux/amd64 $DOCKER_IMAGE > build/libs/vec/shared/amd64/libvec.so diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index c9a49ad2d1d4d..0fa17109fac6b 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -13,20 +13,16 @@ #include <stddef.h> #include <stdint.h> -#ifndef DOT7U_STRIDE_BYTES_LEN -#define DOT7U_STRIDE_BYTES_LEN 32 // Must be a power of 2 -#endif - -#ifndef SQR7U_STRIDE_BYTES_LEN -#define SQR7U_STRIDE_BYTES_LEN 32 +#ifndef STRIDE_BYTES_LEN +#define STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2 #endif #ifdef _MSC_VER #include <intrin.h> -#elif __GNUC__ -#include <immintrin.h> #elif __clang__ #include <immintrin.h> +#elif __GNUC__ +#include <immintrin.h> #endif // Multi-platform CPUID "intrinsic"; it takes as input a "functionNumber" (or "leaf", the eax registry). "Subleaf" @@ -67,9 +63,19 @@ EXPORT int vec_caps() { if (functionIds >= 7) { cpuid(cpuInfo, 7); int ebx = cpuInfo[1]; + int ecx = cpuInfo[2]; // AVX2 flag is the 5th bit // We assume that all processors that have AVX2 also have FMA3 - return (ebx & (1 << 5)) != 0; + int avx2 = (ebx & 0x00000020) != 0; + int avx512 = (ebx & 0x00010000) != 0; + // int avx512_vnni = (ecx & 0x00000800) != 0; + // if (avx512 && avx512_vnni) { + if (avx512) { + return 2; + } + if (avx2) { + return 1; + } } return 0; } @@ -81,7 +87,7 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) { __m256i acc1 = _mm256_setzero_si256(); #pragma GCC unroll 4 - for(int i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) { + for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) { // Load packed 8-bit integers __m256i va1 = _mm256_loadu_si256(a + i); __m256i vb1 = _mm256_loadu_si256(b + i); @@ -101,8 +107,8 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) { EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; - if (dims > DOT7U_STRIDE_BYTES_LEN) { - i += 
dims & ~(DOT7U_STRIDE_BYTES_LEN - 1); + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); res = dot7u_inner(a, b, i); } for (; i < dims; i++) { @@ -118,7 +124,7 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { const __m256i ones = _mm256_set1_epi16(1); #pragma GCC unroll 4 - for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) { + for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) { // Load packed 8-bit integers __m256i va1 = _mm256_loadu_si256(a + i); __m256i vb1 = _mm256_loadu_si256(b + i); @@ -126,7 +132,6 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { const __m256i dist1 = _mm256_sub_epi8(va1, vb1); const __m256i abs_dist1 = _mm256_sign_epi8(dist1, dist1); const __m256i sqr1 = _mm256_maddubs_epi16(abs_dist1, abs_dist1); - acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, sqr1), acc1); } @@ -137,8 +142,8 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { int32_t res = 0; int i = 0; - if (dims > SQR7U_STRIDE_BYTES_LEN) { - i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1); + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); res = sqr7u_inner(a, b, i); } for (; i < dims; i++) { @@ -147,4 +152,3 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { } return res; } - diff --git a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp new file mode 100644 index 0000000000000..1606b31907405 --- /dev/null +++ b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp @@ -0,0 +1,201 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +#include <stddef.h> +#include <stdint.h> +#include "vec.h" + +#ifdef _MSC_VER +#include <intrin.h> +#elif __clang__ +#pragma clang attribute push(__attribute__((target("arch=skylake-avx512"))), apply_to=function) +#include <immintrin.h> +#elif __GNUC__ +#pragma GCC push_options +#pragma GCC target ("arch=skylake-avx512") +#include <immintrin.h> +#endif + +#include <stddef.h> +#include <stdint.h> + +#ifndef STRIDE_BYTES_LEN +#define STRIDE_BYTES_LEN sizeof(__m512i) // Must be a power of 2 +#endif + +// Returns acc + ( p1 * p2 ), for 64-wide int lanes. +template <int offsetRegs> +inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) { + constexpr int lanes = offsetRegs * STRIDE_BYTES_LEN; + const __m512i a = _mm512_loadu_si512((const __m512i*)(p1 + lanes)); + const __m512i b = _mm512_loadu_si512((const __m512i*)(p2 + lanes)); + // Perform multiplication and create 16-bit values + // Vertically multiply each unsigned 8-bit integer from a with the corresponding + // signed 8-bit integer from b, producing intermediate signed 16-bit integers. + // These values will be at max 32385, at min −32640 + const __m512i dot = _mm512_maddubs_epi16(a, b); + const __m512i ones = _mm512_set1_epi16(1); + // Horizontally add adjacent pairs of intermediate signed 16-bit ints, and pack the results in 32-bit ints. + // Using madd with 1, as this is faster than extract 2 halves, add 16-bit ints, and convert to 32-bit ints. 
+ return _mm512_add_epi32(_mm512_madd_epi16(ones, dot), acc); +} + +static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) { + constexpr int stride8 = 8 * STRIDE_BYTES_LEN; + constexpr int stride4 = 4 * STRIDE_BYTES_LEN; + const int8_t* p1 = a; + const int8_t* p2 = b; + + // Init accumulator(s) with 0 + __m512i acc0 = _mm512_setzero_si512(); + __m512i acc1 = _mm512_setzero_si512(); + __m512i acc2 = _mm512_setzero_si512(); + __m512i acc3 = _mm512_setzero_si512(); + __m512i acc4 = _mm512_setzero_si512(); + __m512i acc5 = _mm512_setzero_si512(); + __m512i acc6 = _mm512_setzero_si512(); + __m512i acc7 = _mm512_setzero_si512(); + + const int8_t* p1End = a + (dims & ~(stride8 - 1)); + while (p1 < p1End) { + acc0 = fma8<0>(acc0, p1, p2); + acc1 = fma8<1>(acc1, p1, p2); + acc2 = fma8<2>(acc2, p1, p2); + acc3 = fma8<3>(acc3, p1, p2); + acc4 = fma8<4>(acc4, p1, p2); + acc5 = fma8<5>(acc5, p1, p2); + acc6 = fma8<6>(acc6, p1, p2); + acc7 = fma8<7>(acc7, p1, p2); + p1 += stride8; + p2 += stride8; + } + + p1End = a + (dims & ~(stride4 - 1)); + while (p1 < p1End) { + acc0 = fma8<0>(acc0, p1, p2); + acc1 = fma8<1>(acc1, p1, p2); + acc2 = fma8<2>(acc2, p1, p2); + acc3 = fma8<3>(acc3, p1, p2); + p1 += stride4; + p2 += stride4; + } + + p1End = a + (dims & ~(STRIDE_BYTES_LEN - 1)); + while (p1 < p1End) { + acc0 = fma8<0>(acc0, p1, p2); + p1 += STRIDE_BYTES_LEN; + p2 += STRIDE_BYTES_LEN; + } + + // reduce (accumulate all) + acc0 = _mm512_add_epi32(_mm512_add_epi32(acc0, acc1), _mm512_add_epi32(acc2, acc3)); + acc4 = _mm512_add_epi32(_mm512_add_epi32(acc4, acc5), _mm512_add_epi32(acc6, acc7)); + return _mm512_reduce_add_epi32(_mm512_add_epi32(acc0, acc4)); +} + +extern "C" +EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) { + int32_t res = 0; + int i = 0; + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); + res = dot7u_inner_avx512(a, b, i); + } + for (; i < dims; i++) { + res += a[i] * b[i]; + } + return res; +} + +template <int offsetRegs> +inline 
__m512i sqr8(__m512i acc, const int8_t* p1, const int8_t* p2) { + constexpr int lanes = offsetRegs * STRIDE_BYTES_LEN; + const __m512i a = _mm512_loadu_si512((const __m512i*)(p1 + lanes)); + const __m512i b = _mm512_loadu_si512((const __m512i*)(p2 + lanes)); + + const __m512i dist = _mm512_sub_epi8(a, b); + const __m512i abs_dist = _mm512_abs_epi8(dist); + const __m512i sqr_add = _mm512_maddubs_epi16(abs_dist, abs_dist); + const __m512i ones = _mm512_set1_epi16(1); + // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. + return _mm512_add_epi32(_mm512_madd_epi16(ones, sqr_add), acc); +} + +static inline int32_t sqr7u_inner_avx512(int8_t *a, int8_t *b, size_t dims) { + constexpr int stride8 = 8 * STRIDE_BYTES_LEN; + constexpr int stride4 = 4 * STRIDE_BYTES_LEN; + const int8_t* p1 = a; + const int8_t* p2 = b; + + // Init accumulator(s) with 0 + __m512i acc0 = _mm512_setzero_si512(); + __m512i acc1 = _mm512_setzero_si512(); + __m512i acc2 = _mm512_setzero_si512(); + __m512i acc3 = _mm512_setzero_si512(); + __m512i acc4 = _mm512_setzero_si512(); + __m512i acc5 = _mm512_setzero_si512(); + __m512i acc6 = _mm512_setzero_si512(); + __m512i acc7 = _mm512_setzero_si512(); + + const int8_t* p1End = a + (dims & ~(stride8 - 1)); + while (p1 < p1End) { + acc0 = sqr8<0>(acc0, p1, p2); + acc1 = sqr8<1>(acc1, p1, p2); + acc2 = sqr8<2>(acc2, p1, p2); + acc3 = sqr8<3>(acc3, p1, p2); + acc4 = sqr8<4>(acc4, p1, p2); + acc5 = sqr8<5>(acc5, p1, p2); + acc6 = sqr8<6>(acc6, p1, p2); + acc7 = sqr8<7>(acc7, p1, p2); + p1 += stride8; + p2 += stride8; + } + + p1End = a + (dims & ~(stride4 - 1)); + while (p1 < p1End) { + acc0 = sqr8<0>(acc0, p1, p2); + acc1 = sqr8<1>(acc1, p1, p2); + acc2 = sqr8<2>(acc2, p1, p2); + acc3 = sqr8<3>(acc3, p1, p2); + p1 += stride4; + p2 += stride4; + } + + p1End = a + (dims & ~(STRIDE_BYTES_LEN - 1)); + while (p1 < p1End) { + acc0 = sqr8<0>(acc0, p1, p2); + p1 += STRIDE_BYTES_LEN; + p2 += STRIDE_BYTES_LEN; + } + + // 
reduce (accumulate all) + acc0 = _mm512_add_epi32(_mm512_add_epi32(acc0, acc1), _mm512_add_epi32(acc2, acc3)); + acc4 = _mm512_add_epi32(_mm512_add_epi32(acc4, acc5), _mm512_add_epi32(acc6, acc7)); + return _mm512_reduce_add_epi32(_mm512_add_epi32(acc0, acc4)); +} + +extern "C" +EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { + int32_t res = 0; + int i = 0; + if (dims > STRIDE_BYTES_LEN) { + i += dims & ~(STRIDE_BYTES_LEN - 1); + res = sqr7u_inner_avx512(a, b, i); + } + for (; i < dims; i++) { + int32_t dist = a[i] - b[i]; + res += dist * dist; + } + return res; +} + +#ifdef __clang__ +#pragma clang attribute pop +#elif __GNUC__ +#pragma GCC pop_options +#endif