Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/109084.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 109084
summary: Add AVX-512 optimised vector distance functions for int7 on x64
area: Search
type: enhancement
issues: []
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ configurations {
}

var zstdVersion = "1.5.5"
var vecVersion = "1.0.9"
var vecVersion = "1.0.10"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

public final class JdkVectorLibrary implements VectorLibrary {

static final MethodHandle dot7u$mh;
static final MethodHandle sqr7u$mh;

static final VectorSimilarityFunctions INSTANCE;

static {
Expand All @@ -32,8 +35,33 @@ public final class JdkVectorLibrary implements VectorLibrary {
try {
int caps = (int) vecCaps$mh.invokeExact();
if (caps != 0) {
if (caps == 2) {
dot7u$mh = downcallHandle(
"dot7u_2",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
sqr7u$mh = downcallHandle(
"sqr7u_2",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
} else {
dot7u$mh = downcallHandle(
"dot7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
sqr7u$mh = downcallHandle(
"sqr7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
}
INSTANCE = new JdkVectorSimilarityFunctions();
} else {
dot7u$mh = null;
sqr7u$mh = null;
INSTANCE = null;
}
} catch (Throwable t) {
Expand All @@ -49,18 +77,6 @@ public VectorSimilarityFunctions getVectorSimilarityFunctions() {
}

private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {

static final MethodHandle dot7u$mh = downcallHandle(
"dot7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
static final MethodHandle sqr7u$mh = downcallHandle(
"sqr7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);

/**
* Computes the dot product of given unsigned int7 byte vectors.
*
Expand Down Expand Up @@ -103,15 +119,15 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {

private static int dot7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) dot7u$mh.invokeExact(a, b, length);
return (int) JdkVectorLibrary.dot7u$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) sqr7u$mh.invokeExact(a, b, length);
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
Expand Down
File renamed without changes.
16 changes: 16 additions & 0 deletions libs/simdvec/native/Dockerfile.amd64
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
FROM debian:latest

RUN apt update
RUN apt install -y wget
RUN echo "deb http://apt.llvm.org/bookworm/ llvm-toolchain-bookworm-18 main" > /etc/apt/sources.list.d/clang.list
RUN wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
RUN apt update
RUN apt install -y clang-18 openjdk-17-jdk
RUN ln -s /usr/bin/clang-18 /usr/bin/clang
RUN ln -s /usr/bin/clang++-18 /usr/bin/clang++
COPY . /workspace
WORKDIR /workspace
RUN ./gradlew --quiet --console=plain clean buildSharedLibrary
RUN strip --strip-unneeded build/output/libvec.so

CMD cat build/output/libvec.so
17 changes: 15 additions & 2 deletions libs/simdvec/native/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
* Side Public License, v 1.
*/
apply plugin: 'c'
apply plugin: 'cpp'

var os = org.gradle.internal.os.OperatingSystem.current()

// To update this library run publish_vec_binaries.sh ( or ./gradlew vecSharedLibrary )
// Or
// For local development, build the docker image with:
// docker build --platform linux/arm64 --progress=plain . (for aarch64)
// docker build --platform linux/amd64 --progress=plain . (for x64)
// docker build --platform linux/arm64 --progress=plain --file=Dockerfile.aarch64 . (for aarch64)
// docker build --platform linux/amd64 --progress=plain --file=Dockerfile.amd64 . (for x64)
// Grab the image id from the console output, then, e.g.
// docker run 9c9f36564c148b275aeecc42749e7b4580ded79dcf51ff6ccc008c8861e7a979 > build/libs/vec/shared/$arch/libvec.so
//
Expand Down Expand Up @@ -51,6 +52,8 @@ model {
target("amd64") {
cCompiler.executable = "/usr/bin/gcc"
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2", "-Wno-incompatible-pointer-types"]) }
cppCompiler.executable = "/usr/bin/g++"
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) }
}
}
cl(VisualCpp) {
Expand All @@ -68,6 +71,7 @@ model {

target("amd64") {
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2"]) }
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) }
}
}
}
Expand All @@ -86,6 +90,15 @@ model {
srcDir "src/vec/headers/"
}
}
cpp {
source {
srcDir "src/vec/c/${platformName}/"
include "*.cpp"
}
exportedHeaders {
srcDir "src/vec/headers/"
}
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions libs/simdvec/native/publish_vec_binaries.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi

VERSION="1.0.9"
VERSION="1.0.10"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)

Expand All @@ -33,11 +33,11 @@ echo 'Building Darwin binary...'

echo 'Building Linux binary...'
mkdir -p build/libs/vec/shared/aarch64/
DOCKER_IMAGE=$(docker build --platform linux/arm64 --quiet .)
DOCKER_IMAGE=$(docker build --platform linux/arm64 --quiet --file=Dockerfile.aarch64 .)
docker run $DOCKER_IMAGE > build/libs/vec/shared/aarch64/libvec.so

echo 'Building Linux x64 binary...'
DOCKER_IMAGE=$(docker build --platform linux/amd64 --quiet .)
DOCKER_IMAGE=$(docker build --platform linux/amd64 --quiet --file=Dockerfile.amd64 .)
mkdir -p build/libs/vec/shared/amd64
docker run --platform linux/amd64 $DOCKER_IMAGE > build/libs/vec/shared/amd64/libvec.so

Expand Down
38 changes: 21 additions & 17 deletions libs/simdvec/native/src/vec/c/amd64/vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@
#include <emmintrin.h>
#include <immintrin.h>

#ifndef DOT7U_STRIDE_BYTES_LEN
#define DOT7U_STRIDE_BYTES_LEN 32 // Must be a power of 2
#endif

#ifndef SQR7U_STRIDE_BYTES_LEN
#define SQR7U_STRIDE_BYTES_LEN 32 // Must be a power of 2
#ifndef STRIDE_BYTES_LEN
#define STRIDE_BYTES_LEN sizeof(__m256i) // Must be a power of 2
#endif

#ifdef _MSC_VER
#include <intrin.h>
#elif __GNUC__
#include <x86intrin.h>
#elif __clang__
#include <x86intrin.h>
#elif __GNUC__
#include <x86intrin.h>
#endif

// Multi-platform CPUID "intrinsic"; it takes as input a "functionNumber" (or "leaf", the eax registry). "Subleaf"
Expand Down Expand Up @@ -67,9 +63,19 @@ EXPORT int vec_caps() {
if (functionIds >= 7) {
cpuid(cpuInfo, 7);
int ebx = cpuInfo[1];
int ecx = cpuInfo[2];
// AVX2 flag is the 5th bit
// We assume that all processors that have AVX2 also have FMA3
return (ebx & (1 << 5)) != 0;
int avx2 = (ebx & 0x00000020) != 0;
int avx512 = (ebx & 0x00010000) != 0;
// int avx512_vnni = (ecx & 0x00000800) != 0;
// if (avx512 && avx512_vnni) {
if (avx512) {
return 2;
}
if (avx2) {
return 1;
}
}
return 0;
}
Expand All @@ -81,7 +87,7 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) {
__m256i acc1 = _mm256_setzero_si256();

#pragma GCC unroll 4
for(int i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) {
for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) {
// Load packed 8-bit integers
__m256i va1 = _mm256_loadu_si256(a + i);
__m256i vb1 = _mm256_loadu_si256(b + i);
Expand All @@ -101,8 +107,8 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) {
EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
int32_t res = 0;
int i = 0;
if (dims > DOT7U_STRIDE_BYTES_LEN) {
i += dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
if (dims > STRIDE_BYTES_LEN) {
i += dims & ~(STRIDE_BYTES_LEN - 1);
res = dot7u_inner(a, b, i);
}
for (; i < dims; i++) {
Expand All @@ -118,15 +124,14 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) {
const __m256i ones = _mm256_set1_epi16(1);

#pragma GCC unroll 4
for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) {
for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) {
// Load packed 8-bit integers
__m256i va1 = _mm256_loadu_si256(a + i);
__m256i vb1 = _mm256_loadu_si256(b + i);

const __m256i dist1 = _mm256_sub_epi8(va1, vb1);
const __m256i abs_dist1 = _mm256_sign_epi8(dist1, dist1);
const __m256i sqr1 = _mm256_maddubs_epi16(abs_dist1, abs_dist1);

acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, sqr1), acc1);
}

Expand All @@ -137,8 +142,8 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) {
EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) {
int32_t res = 0;
int i = 0;
if (dims > SQR7U_STRIDE_BYTES_LEN) {
i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1);
if (dims > STRIDE_BYTES_LEN) {
i += dims & ~(STRIDE_BYTES_LEN - 1);
res = sqr7u_inner(a, b, i);
}
for (; i < dims; i++) {
Expand All @@ -147,4 +152,3 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) {
}
return res;
}

Loading