Skip to content

Commit

Permalink
Add index_uncheck to avoid boundary check (tikv#1468)
Browse files Browse the repository at this point in the history
Signed-off-by: EricZequan <[email protected]>
  • Loading branch information
EricZequan authored May 14, 2024
1 parent 5ac0eef commit 57ba446
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 11 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions components/tidb_query_datatype/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,11 @@ tidb_query_common = { workspace = true }
tikv_alloc = { workspace = true }
tikv_util = { workspace = true }
tipb = { workspace = true }

[dev-dependencies]
criterion = "0.3"

[[bench]]
name = "bench_vector_distance"
path = "benches/bench_vector_distance.rs"
harness = false
191 changes: 191 additions & 0 deletions components/tidb_query_datatype/benches/bench_vector_distance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
// Copyright 2024 TiKV Project Authors. Licensed under Apache-2.0.

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use tidb_query_datatype::codec::mysql::VectorFloat32Ref;

/// Benchmarks `l1_distance` on a pair of identical 3-dimensional vectors.
fn bench_l1_distance_3d(c: &mut Criterion) {
    let lhs_data = vec![1.1_f32, 2.2, 3.3];
    let rhs_data = vec![1.1_f32, 2.2, 3.3];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("l1_distance_3d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs).l1_distance(black_box(rhs)).unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `l1_distance` on a pair of identical 784-dimensional vectors.
fn bench_l1_distance_784d(c: &mut Criterion) {
    let lhs_data = vec![1.0_f32; 784];
    let rhs_data = vec![1.0_f32; 784];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("l1_distance_784d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs).l1_distance(black_box(rhs)).unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `l2_squared_distance` on a pair of identical 3-dimensional
/// vectors.
fn bench_l2_squared_distance_3d(c: &mut Criterion) {
    let lhs_data = vec![1.1_f32, 2.2, 3.3];
    let rhs_data = vec![1.1_f32, 2.2, 3.3];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("l2_squared_distance_3d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs)
                .l2_squared_distance(black_box(rhs))
                .unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `l2_squared_distance` on a pair of identical 784-dimensional
/// vectors.
fn bench_l2_squared_distance_784d(c: &mut Criterion) {
    let lhs_data = vec![1.0_f32; 784];
    let rhs_data = vec![1.0_f32; 784];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("l2_squared_distance_784d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs)
                .l2_squared_distance(black_box(rhs))
                .unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `l2_distance` on a pair of identical 3-dimensional vectors.
fn bench_l2_distance_3d(c: &mut Criterion) {
    let lhs_data = vec![1.1_f32, 2.2, 3.3];
    let rhs_data = vec![1.1_f32, 2.2, 3.3];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("l2_distance_3d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs).l2_distance(black_box(rhs)).unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `l2_distance` on a pair of identical 784-dimensional vectors.
fn bench_l2_distance_784d(c: &mut Criterion) {
    let lhs_data = vec![1.0_f32; 784];
    let rhs_data = vec![1.0_f32; 784];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("l2_distance_784d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs).l2_distance(black_box(rhs)).unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `inner_product` on a pair of identical 3-dimensional vectors.
fn bench_inner_product_3d(c: &mut Criterion) {
    let lhs_data = vec![1.1_f32, 2.2, 3.3];
    let rhs_data = vec![1.1_f32, 2.2, 3.3];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("inner_product_3d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let product = black_box(lhs).inner_product(black_box(rhs)).unwrap();
            black_box(product);
        })
    });
}

/// Benchmarks `inner_product` on a pair of identical 784-dimensional vectors.
fn bench_inner_product_784d(c: &mut Criterion) {
    let lhs_data = vec![1.0_f32; 784];
    let rhs_data = vec![1.0_f32; 784];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("inner_product_784d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let product = black_box(lhs).inner_product(black_box(rhs)).unwrap();
            black_box(product);
        })
    });
}

/// Benchmarks `cosine_distance` on a pair of identical 3-dimensional vectors.
fn bench_cosine_distance_3d(c: &mut Criterion) {
    let lhs_data = vec![1.1_f32, 2.2, 3.3];
    let rhs_data = vec![1.1_f32, 2.2, 3.3];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("cosine_distance_3d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs)
                .cosine_distance(black_box(rhs))
                .unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `cosine_distance` on a pair of identical 784-dimensional vectors.
fn bench_cosine_distance_784d(c: &mut Criterion) {
    let lhs_data = vec![1.0_f32; 784];
    let rhs_data = vec![1.0_f32; 784];
    let lhs = VectorFloat32Ref::from_f32(&lhs_data[..]).unwrap();
    let rhs = VectorFloat32Ref::from_f32(&rhs_data[..]).unwrap();

    c.bench_function("cosine_distance_784d", |bencher| {
        bencher.iter(|| {
            // Inputs and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let distance = black_box(lhs)
                .cosine_distance(black_box(rhs))
                .unwrap();
            black_box(distance);
        })
    });
}

/// Benchmarks `l2_norm` on a single 3-dimensional vector.
fn bench_l2_norm_3d(c: &mut Criterion) {
    let data = vec![1.1_f32, 2.2, 3.3];
    let vector = VectorFloat32Ref::from_f32(&data[..]).unwrap();

    c.bench_function("l2_norm_3d", |bencher| {
        bencher.iter(|| {
            // Input and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let norm = black_box(vector).l2_norm();
            black_box(norm);
        })
    });
}

/// Benchmarks `l2_norm` on a single 784-dimensional vector.
fn bench_l2_norm_784d(c: &mut Criterion) {
    let data = vec![1.0_f32; 784];
    let vector = VectorFloat32Ref::from_f32(&data[..]).unwrap();

    c.bench_function("l2_norm_784d", |bencher| {
        bencher.iter(|| {
            // Input and result pass through `black_box` so the optimizer
            // cannot fold the measured call away.
            let norm = black_box(vector).l2_norm();
            black_box(norm);
        })
    });
}

// Collects every benchmark function in this file into the `benches` group.
// Each distance/norm operation is listed twice: once for the 3-dimensional
// input and once for the 784-dimensional input.
criterion_group!(
benches,
bench_l1_distance_3d,
bench_l1_distance_784d,
bench_l2_squared_distance_3d,
bench_l2_squared_distance_784d,
bench_l2_distance_3d,
bench_l2_distance_784d,
bench_inner_product_3d,
bench_inner_product_784d,
bench_cosine_distance_3d,
bench_cosine_distance_784d,
bench_l2_norm_3d,
bench_l2_norm_784d,
);
// Entry point for the bench target that runs the `benches` group; the target
// is declared with `harness = false` in Cargo.toml.
criterion_main!(benches);
45 changes: 34 additions & 11 deletions components/tidb_query_datatype/src/codec/mysql/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,24 +174,45 @@ impl<'a> VectorFloat32Ref<'a> {
}

fn index(&self, idx: usize) -> f32 {
let byte_index: usize = idx * 4;
if byte_index + 4 > self.value.len() {
if idx > self.len() {
panic!(
"Index out of bounds: index = {}, length = {}",
idx,
self.len()
);
}
let float_ptr = unsafe { self.value.as_ptr().add(byte_index) as *const f32 };
unsafe { float_ptr.read_unaligned() }
let byte_index: usize = idx * 4;
unsafe {
let float_ptr = self.value.as_ptr().add(byte_index) as *const f32;
float_ptr.read_unaligned()
}
}

/// Returns the `f32` element stored at position `idx` without a bounds check
/// in release builds.
///
/// Builds with debug assertions enabled still verify the index and panic with
/// an "Index out of bounds" message, so misuse is caught during development.
///
/// # Safety
///
/// The caller must guarantee `idx < self.len()`. Otherwise the unaligned read
/// touches bytes past the end of `self.value`, which is undefined behavior.
unsafe fn index_unchecked(&self, idx: usize) -> f32 {
    // `idx == self.len()` is already one element past the end, so the valid
    // range is `idx < self.len()`; testing `idx > self.len()` would miss it.
    debug_assert!(
        idx < self.len(),
        "Index out of bounds: index = {}, length = {}",
        idx,
        self.len()
    );
    // Each element occupies 4 bytes of `self.value`.
    let byte_index: usize = idx * 4;
    // SAFETY: the caller guarantees `idx < self.len()`, so bytes
    // `byte_index..byte_index + 4` lie inside `self.value`
    // (NOTE(review): assumes `len()` is `self.value.len() / 4` — confirm).
    // `read_unaligned` is used because a byte buffer carries no `f32`
    // alignment guarantee.
    let float_ptr = self.value.as_ptr().add(byte_index) as *const f32;
    float_ptr.read_unaligned()
}

pub fn l2_squared_distance(&self, b: VectorFloat32Ref<'a>) -> Result<f64> {
self.check_dims(b)?;
let mut distance: f32 = 0.0;

for i in 0..self.len() {
let diff = self.index(i) - b.index(i);
let diff = unsafe { self.index_unchecked(i) - b.index_unchecked(i) };
distance += diff * diff;
}

Expand All @@ -206,7 +227,7 @@ impl<'a> VectorFloat32Ref<'a> {
self.check_dims(b)?;
let mut distance: f32 = 0.0;
for i in 0..self.len() {
distance += self.index(i) * b.index(i);
distance += unsafe { self.index_unchecked(i) * b.index_unchecked(i) };
}

Ok(distance as f64)
Expand All @@ -218,9 +239,11 @@ impl<'a> VectorFloat32Ref<'a> {
let mut norma: f32 = 0.0;
let mut normb: f32 = 0.0;
for i in 0..self.len() {
distance += self.index(i) * b.index(i);
norma += self.index(i) * self.index(i);
normb += b.index(i) * b.index(i);
unsafe {
distance += self.index_unchecked(i) * b.index_unchecked(i);
norma += self.index_unchecked(i) * self.index_unchecked(i);
normb += b.index_unchecked(i) * b.index_unchecked(i);
}
}

let similarity = (distance as f64) / ((norma as f64) * (normb as f64)).sqrt();
Expand All @@ -236,7 +259,7 @@ impl<'a> VectorFloat32Ref<'a> {
self.check_dims(b)?;
let mut distance: f32 = 0.0;
for i in 0..self.len() {
let diff = self.index(i) - b.index(i);
let diff = unsafe { self.index_unchecked(i) - b.index_unchecked(i) };
distance += diff.abs();
}

Expand All @@ -248,7 +271,7 @@ impl<'a> VectorFloat32Ref<'a> {
// precision during calculation.
let mut norm: f64 = 0.0;
for i in 0..self.len() {
let v = self.index(i) as f64;
let v = unsafe { self.index_unchecked(i) as f64 };
norm += v * v;
}

Expand Down

0 comments on commit 57ba446

Please sign in to comment.