Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion java/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ pub fn get_vector_index_params(
"getRqParams",
|env, rq_obj| {
let num_bits = env.call_method(&rq_obj, "getNumBits", "()B", &[])?.b()? as u8;
Ok(RQBuildParams { num_bits })
Ok(RQBuildParams::new(num_bits))
},
)?;

Expand Down
5 changes: 5 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2905,6 +2905,11 @@ def create_index(
- index_file_version
The version of the index file. Default is "V3".

Optional parameters for `IVF_RQ`:

- num_bits
The number of bits for RQ (Rabit Quantization). Default is 1.

Optional parameters for `IVF_HNSW_*`:
max_level
Int, the maximum number of levels in the graph.
Expand Down
106 changes: 59 additions & 47 deletions rust/lance-index/benches/rq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ use lance_datagen::{BatchGeneratorBuilder, RowCount};
use lance_index::vector::bq::builder::RabitQuantizer;
use lance_index::vector::bq::storage::*;
use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN};
use lance_index::vector::bq::RQRotationType;
use lance_index::vector::quantizer::{Quantization, QuantizerStorage};
use lance_index::vector::storage::{DistCalculator, VectorStore};
use lance_linalg::distance::DistanceType;

const DIM: usize = 128;
const TOTAL: usize = 16 * 1000;

fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage {
fn mock_rq_storage(num_bits: u8, rotation_type: RQRotationType) -> RabitQuantizationStorage {
// generate random rq codes
let rq = RabitQuantizer::new::<Float32Type>(num_bits, DIM as i32);
let rq = RabitQuantizer::new_with_rotation::<Float32Type>(num_bits, DIM as i32, rotation_type);
let builder = BatchGeneratorBuilder::new()
.col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
.col(
Expand All @@ -49,59 +50,70 @@ fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage {
}

/// Benchmarks `dist_calculator` construction on a mock RQ storage, once per
/// rotation type.
///
/// Each benchmark id embeds the bit width, the rotation type (via `Debug`),
/// `DistanceType::L2` and `DIM`, so results for `Fast` and `Matrix` are
/// reported separately.
fn construct_dist_table(c: &mut Criterion) {
    let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix];
    // Only 1-bit codes are exercised; the range form leaves room to widen it.
    for num_bits in 1..=1 {
        for rotation_type in rotation_types {
            let rq = mock_rq_storage(num_bits, rotation_type);
            // One random Float32 query of DIM values, reused by every iteration.
            let query = rand_type(&DataType::Float32)
                .generate_default(RowCount::from(DIM as u64))
                .unwrap();
            c.bench_function(
                format!(
                    "RQ{}({:?}): construct_dist_table: {},DIM={}",
                    num_bits,
                    rotation_type,
                    DistanceType::L2,
                    DIM
                )
                .as_str(),
                |b| {
                    b.iter(|| {
                        // black_box keeps the optimizer from discarding the call.
                        black_box(rq.dist_calculator(query.clone(), 0.0));
                    })
                },
            );
        }
    }
}

/// Benchmarks distance evaluation on a mock RQ storage, once per rotation type.
///
/// Two benchmarks are registered for every (bit width, rotation type) pair:
/// * `compute_distances` — a single `distance_all(0)` call;
/// * `compute_distances_single` — `distance(i)` for every `i` in `0..TOTAL`.
///
/// The distance calculator is built once outside the timed closures, so only
/// the distance calls themselves are measured.
fn compute_distances(c: &mut Criterion) {
    let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix];
    // Only 1-bit codes are exercised; the range form leaves room to widen it.
    for num_bits in 1..=1 {
        for rotation_type in rotation_types {
            let rq = mock_rq_storage(num_bits, rotation_type);
            let query = rand_type(&DataType::Float32)
                .generate_default(RowCount::from(DIM as u64))
                .unwrap();
            let dist_calc = rq.dist_calculator(query.clone(), 0.0);

            c.bench_function(
                format!(
                    "RQ{}({:?}): compute_distances: {},DIM={}",
                    num_bits, rotation_type, TOTAL, DIM
                )
                .as_str(),
                |b| {
                    b.iter(|| {
                        black_box(dist_calc.distance_all(0));
                    })
                },
            );

            c.bench_function(
                format!(
                    "RQ{}({:?}): compute_distances_single: {},DIM={}",
                    num_bits, rotation_type, TOTAL, DIM
                )
                .as_str(),
                |b| {
                    b.iter(|| {
                        for i in 0..TOTAL {
                            black_box(dist_calc.distance(i as u32));
                        }
                    })
                },
            );
        }
    }
}

Expand Down
60 changes: 58 additions & 2 deletions rust/lance-index/src/vector/bq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
//! Binary Quantization (BQ)

use std::iter::once;
use std::str::FromStr;
use std::sync::Arc;

use arrow_array::types::Float32Type;
use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array};
use lance_core::{Error, Result};
use num_traits::Float;
use serde::{Deserialize, Serialize};
use snafu::location;

use crate::vector::quantizer::QuantizerBuildParams;

pub mod builder;
pub mod rotation;
pub mod storage;
pub mod transform;

Expand Down Expand Up @@ -80,14 +83,51 @@ fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
}))
}

/// Rotation variant that can be selected for RQ via `RQBuildParams::rotation_type`.
///
/// Serialized through serde with snake_case variant names. Text parsing is
/// provided by the `FromStr` implementation, which lowercases its input and
/// accepts a few aliases per variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RQRotationType {
/// The default variant. Parsed from "fast", "fht_kac" or "fht-kac".
// NOTE(review): the aliases suggest an FHT/Kac-style rotation — confirm against the `rotation` module.
#[default]
Fast,
/// Parsed from "matrix" or "dense".
// NOTE(review): presumably a dense rotation matrix — confirm against the `rotation` module.
Matrix,
}

impl FromStr for RQRotationType {
type Err = Error;

fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
match value.to_lowercase().as_str() {
"fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast),
"matrix" | "dense" => Ok(Self::Matrix),
_ => Err(Error::invalid_input(
format!(
"Unknown RQ rotation type: {}. Expected one of: fast, matrix",
value
),
location!(),
)),
}
}
}

/// Build parameters for RQ; this type implements `QuantizerBuildParams`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RQBuildParams {
/// Number of bits used by RQ. The `Default` implementation sets this to 1.
pub num_bits: u8,
/// Rotation variant to use. `new` and `Default` fill this with `RQRotationType::default()`.
pub rotation_type: RQRotationType,
}

impl RQBuildParams {
    /// Creates build parameters with the given bit width and the default
    /// rotation type (`RQRotationType::default()`).
    pub fn new(num_bits: u8) -> Self {
        Self::with_rotation_type(num_bits, RQRotationType::default())
    }

    /// Creates build parameters with an explicit bit width and rotation type.
    pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self {
        Self {
            num_bits,
            rotation_type,
        }
    }
}

Expand All @@ -99,7 +139,10 @@ impl QuantizerBuildParams for RQBuildParams {

impl Default for RQBuildParams {
    /// Returns parameters with `num_bits = 1` and the default rotation type.
    fn default() -> Self {
        Self {
            num_bits: 1,
            rotation_type: RQRotationType::default(),
        }
    }
}

Expand All @@ -126,4 +169,17 @@ mod tests {
test_bq::<f32>();
test_bq::<f64>();
}

#[test]
fn test_rotation_type_parse() {
    // Canonical names must parse to their matching variants.
    let accepted = [
        ("fast", RQRotationType::Fast),
        ("matrix", RQRotationType::Matrix),
    ];
    for (text, expected) in accepted {
        assert_eq!(text.parse::<RQRotationType>().unwrap(), expected);
    }

    // Anything outside the accepted names is rejected.
    let rejected = "invalid".parse::<RQRotationType>();
    assert!(rejected.is_err());
}
}
Loading