2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
@@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.7"
var vecVersion = "1.0.41"
var vecVersion = "1.0.42"

repositories {
exclusiveContent {
2 changes: 1 addition & 1 deletion libs/simdvec/native/publish_vec_binaries.sh
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi

-VERSION="1.0.41"
+VERSION="1.0.42"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)

223 changes: 156 additions & 67 deletions libs/simdvec/native/src/vec/c/aarch64/vec_1.cpp
@@ -355,7 +355,98 @@ static inline void sqri8_inner_bulk(
const int32_t count,
f32_t* results
) {
for (int c = 0; c < count; c++) {
const int blk = dims & ~15;
int c = 0;

// Process 4 vectors at a time; this helps the CPU scheduler/prefetcher.
// Loading from several memory locations while computing tells the prefetcher
// where the next data will come from and keeps the CPU execution units busy.
// Our benchmarks show that this "hint" is more effective than explicit
// prefetch instructions (e.g. __builtin_prefetch) on many ARM processors
// (e.g. Graviton).
for (; c + 3 < count; c += 4) {
const int8_t* a0 = a + mapper(c, offsets) * pitch;
const int8_t* a1 = a + mapper(c + 1, offsets) * pitch;
const int8_t* a2 = a + mapper(c + 2, offsets) * pitch;
const int8_t* a3 = a + mapper(c + 3, offsets) * pitch;
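// Each input vector accumulates into two separate registers (acc0/acc1, acc2/acc3, ...):
// the low and high halves of each widened difference feed different accumulators,
// avoiding one long multiply-accumulate dependency chain per vector.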

int32x4_t acc0 = vdupq_n_s32(0);
int32x4_t acc1 = vdupq_n_s32(0);
int32x4_t acc2 = vdupq_n_s32(0);
int32x4_t acc3 = vdupq_n_s32(0);
int32x4_t acc4 = vdupq_n_s32(0);
int32x4_t acc5 = vdupq_n_s32(0);
int32x4_t acc6 = vdupq_n_s32(0);
int32x4_t acc7 = vdupq_n_s32(0);

for (int i = 0; i < blk; i += 16) {
int8x16_t vb = vld1q_s8(b + i);

int8x16_t v0 = vld1q_s8(a0 + i);
int16x8_t d0_lo = vsubl_s8(vget_low_s8(v0), vget_low_s8(vb));
int16x8_t d0_hi = vsubl_s8(vget_high_s8(v0), vget_high_s8(vb));
acc0 = vmlal_s16(acc0, vget_low_s16(d0_lo), vget_low_s16(d0_lo));
acc1 = vmlal_s16(acc1, vget_high_s16(d0_lo), vget_high_s16(d0_lo));
acc0 = vmlal_s16(acc0, vget_low_s16(d0_hi), vget_low_s16(d0_hi));
acc1 = vmlal_s16(acc1, vget_high_s16(d0_hi), vget_high_s16(d0_hi));

int8x16_t v1 = vld1q_s8(a1 + i);
int16x8_t d1_lo = vsubl_s8(vget_low_s8(v1), vget_low_s8(vb));
int16x8_t d1_hi = vsubl_s8(vget_high_s8(v1), vget_high_s8(vb));
acc2 = vmlal_s16(acc2, vget_low_s16(d1_lo), vget_low_s16(d1_lo));
acc3 = vmlal_s16(acc3, vget_high_s16(d1_lo), vget_high_s16(d1_lo));
acc2 = vmlal_s16(acc2, vget_low_s16(d1_hi), vget_low_s16(d1_hi));
acc3 = vmlal_s16(acc3, vget_high_s16(d1_hi), vget_high_s16(d1_hi));

int8x16_t v2 = vld1q_s8(a2 + i);
int16x8_t d2_lo = vsubl_s8(vget_low_s8(v2), vget_low_s8(vb));
int16x8_t d2_hi = vsubl_s8(vget_high_s8(v2), vget_high_s8(vb));
acc4 = vmlal_s16(acc4, vget_low_s16(d2_lo), vget_low_s16(d2_lo));
acc5 = vmlal_s16(acc5, vget_high_s16(d2_lo), vget_high_s16(d2_lo));
acc4 = vmlal_s16(acc4, vget_low_s16(d2_hi), vget_low_s16(d2_hi));
acc5 = vmlal_s16(acc5, vget_high_s16(d2_hi), vget_high_s16(d2_hi));

int8x16_t v3 = vld1q_s8(a3 + i);
int16x8_t d3_lo = vsubl_s8(vget_low_s8(v3), vget_low_s8(vb));
int16x8_t d3_hi = vsubl_s8(vget_high_s8(v3), vget_high_s8(vb));
acc6 = vmlal_s16(acc6, vget_low_s16(d3_lo), vget_low_s16(d3_lo));
acc7 = vmlal_s16(acc7, vget_high_s16(d3_lo), vget_high_s16(d3_lo));
acc6 = vmlal_s16(acc6, vget_low_s16(d3_hi), vget_low_s16(d3_hi));
acc7 = vmlal_s16(acc7, vget_high_s16(d3_hi), vget_high_s16(d3_hi));
}
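// Pairwise-combine each vector's two accumulators, then horizontally reduce to scalars.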
int32x4_t acc01 = vaddq_s32(acc0, acc1);
int32x4_t acc23 = vaddq_s32(acc2, acc3);
int32x4_t acc45 = vaddq_s32(acc4, acc5);
int32x4_t acc67 = vaddq_s32(acc6, acc7);

int32_t acc_scalar0 = vaddvq_s32(acc01);
int32_t acc_scalar1 = vaddvq_s32(acc23);
int32_t acc_scalar2 = vaddvq_s32(acc45);
int32_t acc_scalar3 = vaddvq_s32(acc67);
if (blk != dims) {
// scalar tail
for (int t = blk; t < dims; t++) {
const int8_t bb = b[t];
int32_t diff0 = a0[t] - bb;
int32_t diff1 = a1[t] - bb;
int32_t diff2 = a2[t] - bb;
int32_t diff3 = a3[t] - bb;

acc_scalar0 += diff0 * diff0;
acc_scalar1 += diff1 * diff1;
acc_scalar2 += diff2 * diff2;
acc_scalar3 += diff3 * diff3;
}
}
results[c + 0] = (f32_t)acc_scalar0;
results[c + 1] = (f32_t)acc_scalar1;
results[c + 2] = (f32_t)acc_scalar2;
results[c + 3] = (f32_t)acc_scalar3;
}

// Tail-handling: remaining vectors
for (; c < count; c++) {
const int8_t* a0 = a + mapper(c, offsets) * pitch;
results[c] = (f32_t)vec_sqri8(a0, b, dims);
}
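For comparison with the interleaved-load approach described in the comment above, the explicit-prefetch alternative it refers to would look roughly like the sketch below. This is an illustrative sketch only, not code from this change; the 64-byte prefetch distance and the helper name are assumptions.

#include <arm_neon.h>
#include <stdint.h>

// Hypothetical one-vector-at-a-time squared-distance kernel that relies on an
// explicit software prefetch instead of interleaving loads from four vectors.
static inline int32_t sqri8_prefetch_sketch(const int8_t* a, const int8_t* b, int dims) {
    int32x4_t acc = vdupq_n_s32(0);
    int i = 0;
    for (; i + 16 <= dims; i += 16) {
        __builtin_prefetch(a + i + 64);  // assumed prefetch distance of 64 bytes
        int8x16_t va = vld1q_s8(a + i);
        int8x16_t vb = vld1q_s8(b + i);
        int16x8_t d_lo = vsubl_s8(vget_low_s8(va), vget_low_s8(vb));
        int16x8_t d_hi = vsubl_s8(vget_high_s8(va), vget_high_s8(vb));
        acc = vmlal_s16(acc, vget_low_s16(d_lo), vget_low_s16(d_lo));
        acc = vmlal_s16(acc, vget_high_s16(d_lo), vget_high_s16(d_lo));
        acc = vmlal_s16(acc, vget_low_s16(d_hi), vget_low_s16(d_hi));
        acc = vmlal_s16(acc, vget_high_s16(d_hi), vget_high_s16(d_hi));
    }
    int32_t sum = vaddvq_s32(acc);
    for (; i < dims; i++) {  // scalar tail for dimensions not divisible by 16
        int32_t d = a[i] - b[i];
        sum += d * d;
    }
    return sum;
}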
@@ -809,71 +900,6 @@ EXPORT int64_t vec_dotd1q4(const int8_t* a, const int8_t* query, const int32_t l
return dotd1q4_inner(a, query, length);
}

EXPORT int64_t vec_dotd2q4(
const int8_t* a,
const int8_t* query,
const int32_t length
) {
int64_t lower = dotd1q4_inner(a, query, length/2);
int64_t upper = dotd1q4_inner(a + length/2, query, length/2);
return lower + (upper << 1);
}

EXPORT int64_t vec_dotd4q4(const int8_t* a, const int8_t* query, const int32_t length) {
const int32_t bit_length = length / 4;
int64_t p0 = dotd1q4_inner(a + 0 * bit_length, query, bit_length);
int64_t p1 = dotd1q4_inner(a + 1 * bit_length, query, bit_length);
int64_t p2 = dotd1q4_inner(a + 2 * bit_length, query, bit_length);
int64_t p3 = dotd1q4_inner(a + 3 * bit_length, query, bit_length);
return p0 + (p1 << 1) + (p2 << 2) + (p3 << 3);
}

template <int64_t(*mapper)(const int32_t, const int32_t*)>
static inline void dotd4q4_inner_bulk(
const int8_t* a,
const int8_t* query,
const int32_t length,
const int32_t pitch,
const int32_t* offsets,
const int32_t count,
f32_t* results
) {
const int32_t bit_length = length / 4;

for (int c = 0; c < count; c++) {
const int8_t* a0 = a + mapper(c, offsets) * pitch;

int64_t p0 = dotd1q4_inner(a0 + 0 * bit_length, query, bit_length);
int64_t p1 = dotd1q4_inner(a0 + 1 * bit_length, query, bit_length);
int64_t p2 = dotd1q4_inner(a0 + 2 * bit_length, query, bit_length);
int64_t p3 = dotd1q4_inner(a0 + 3 * bit_length, query, bit_length);

results[c] = (f32_t)(p0 + (p1 << 1) + (p2 << 2) + (p3 << 3));
}
}

EXPORT void vec_dotd4q4_bulk(
const int8_t* a,
const int8_t* query,
const int32_t length,
const int32_t count,
f32_t* results
) {
dotd4q4_inner_bulk<identity_mapper>(a, query, length, length, NULL, count, results);
}

EXPORT void vec_dotd4q4_bulk_offsets(
const int8_t* a,
const int8_t* query,
const int32_t length,
const int32_t pitch,
const int32_t* offsets,
const int32_t count,
f32_t* results
) {
dotd4q4_inner_bulk<array_mapper>(a, query, length, pitch, offsets, count, results);
}

template <int64_t(*mapper)(const int32_t, const int32_t*)>
static inline void dotd1q4_inner_bulk(
const int8_t* a,
@@ -1013,6 +1039,15 @@ EXPORT void vec_dotd1q4_bulk_offsets(
dotd1q4_inner_bulk<array_mapper>(a, query, length, pitch, offsets, count, results);
}

Review comment (Member Author): This just moves the methods around so they're in a consistent order.

EXPORT int64_t vec_dotd2q4(
const int8_t* a,
const int8_t* query,
const int32_t length
) {
int64_t lower = dotd1q4_inner(a, query, length/2);
int64_t upper = dotd1q4_inner(a + length/2, query, length/2);
return lower + (upper << 1);
}

template <int64_t(*mapper)(const int32_t, const int32_t*)>
static inline void dotd2q4_inner_bulk(
@@ -1026,7 +1061,6 @@ static inline void dotd2q4_inner_bulk(
) {
int c = 0;
const int bit_length = length/2;
// TODO: specialised implementation
for (; c < count; c++) {
const int8_t* a0 = a + mapper(c, offsets) * pitch;
int64_t lower = dotd1q4_inner(a0, query, bit_length);
@@ -1054,3 +1088,58 @@ EXPORT void vec_dotd2q4_bulk_offsets(
f32_t* results) {
dotd2q4_inner_bulk<array_mapper>(a, query, length, pitch, offsets, count, results);
}

EXPORT int64_t vec_dotd4q4(const int8_t* a, const int8_t* query, const int32_t length) {
const int32_t bit_length = length / 4;
int64_t p0 = dotd1q4_inner(a + 0 * bit_length, query, bit_length);
int64_t p1 = dotd1q4_inner(a + 1 * bit_length, query, bit_length);
int64_t p2 = dotd1q4_inner(a + 2 * bit_length, query, bit_length);
int64_t p3 = dotd1q4_inner(a + 3 * bit_length, query, bit_length);
return p0 + (p1 << 1) + (p2 << 2) + (p3 << 3);
}
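The shift-and-add combination in vec_dotd2q4 and vec_dotd4q4 follows from the bit-plane decomposition of the stored values: assuming each dotd1q4_inner call computes the plain dot product of one bit-plane with the query, and each 4-bit value decomposes as d = d0 + 2*d1 + 4*d2 + 8*d3, then sum(d * q) equals p0 + (p1 << 1) + (p2 << 2) + (p3 << 3). A minimal scalar check of that identity, illustrative only and not code from this change:

#include <assert.h>
#include <stdint.h>

int main() {
    // Four dimensions of 4-bit document values and int8 query values.
    int8_t d[4] = {5, 9, 3, 14};
    int8_t q[4] = {2, 7, 1, 4};
    int64_t p[4] = {0, 0, 0, 0};  // per-bit-plane partial dot products
    int64_t full = 0;
    for (int i = 0; i < 4; i++) {
        full += (int64_t)d[i] * q[i];
        for (int bit = 0; bit < 4; bit++) {
            p[bit] += ((d[i] >> bit) & 1) * q[i];  // dot product of bit-plane `bit`
        }
    }
    // Same weighting as vec_dotd4q4.
    assert(full == p[0] + (p[1] << 1) + (p[2] << 2) + (p[3] << 3));
    return 0;
}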

template <int64_t(*mapper)(const int32_t, const int32_t*)>
static inline void dotd4q4_inner_bulk(
const int8_t* a,
const int8_t* query,
const int32_t length,
const int32_t pitch,
const int32_t* offsets,
const int32_t count,
f32_t* results
) {
const int32_t bit_length = length / 4;

for (int c = 0; c < count; c++) {
const int8_t* a0 = a + mapper(c, offsets) * pitch;

int64_t p0 = dotd1q4_inner(a0 + 0 * bit_length, query, bit_length);
int64_t p1 = dotd1q4_inner(a0 + 1 * bit_length, query, bit_length);
int64_t p2 = dotd1q4_inner(a0 + 2 * bit_length, query, bit_length);
int64_t p3 = dotd1q4_inner(a0 + 3 * bit_length, query, bit_length);

results[c] = (f32_t)(p0 + (p1 << 1) + (p2 << 2) + (p3 << 3));
}
}

EXPORT void vec_dotd4q4_bulk(
const int8_t* a,
const int8_t* query,
const int32_t length,
const int32_t count,
f32_t* results
) {
dotd4q4_inner_bulk<identity_mapper>(a, query, length, length, NULL, count, results);
}

EXPORT void vec_dotd4q4_bulk_offsets(
const int8_t* a,
const int8_t* query,
const int32_t length,
const int32_t pitch,
const int32_t* offsets,
const int32_t count,
f32_t* results
) {
dotd4q4_inner_bulk<array_mapper>(a, query, length, pitch, offsets, count, results);
}
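The identity_mapper and array_mapper functions passed as template arguments above are defined elsewhere in vec_1.cpp and do not appear in this diff. Judging from the call sites (pitch = length and offsets = NULL for the contiguous bulk entry points, an offsets array for the _offsets variants), they presumably look roughly like the sketch below; this is an inference, not the actual definitions.

// Presumed shape of the mappers (inferred from the call sites, not the real code):
static inline int64_t identity_mapper(const int32_t i, const int32_t* /*offsets*/) {
    return i;            // contiguous layout: vector i starts at a + i * pitch
}

static inline int64_t array_mapper(const int32_t i, const int32_t* offsets) {
    return offsets[i];   // scattered layout: vector i starts at a + offsets[i] * pitch
}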