[ENH]: replace get_block_ids_* with get_block_ids_range() in SparseIndex (#2921)

## Description of changes

Replaces specialized methods like `get_block_ids_gt` and `get_block_ids_lt` with a single `get_block_ids_range()` method that behaves similarly to the std `BTreeMap::range()` method. This reduces complexity/repetition and also enables queries that are bounded in both directions.
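
For reference, the analogy to `std::collections::BTreeMap::range()` is direct: any `RangeBounds` value can select the bounds, including an explicit `Bound` pair for the previously unsupported both-ends case. A minimal standalone sketch (the map contents here are made up purely for illustration):

```rust
use std::collections::BTreeMap;
use std::ops::Bound;

fn main() {
    // Hypothetical key -> block id mapping, purely for illustration.
    let index = BTreeMap::from([(1, "block-a"), (5, "block-b"), (9, "block-c")]);

    // Equivalent of the old `get_block_ids_gt`-style query: keys strictly greater than 1.
    let gt: Vec<_> = index.range((Bound::Excluded(1), Bound::Unbounded)).collect();
    assert_eq!(gt, vec![(&5, &"block-b"), (&9, &"block-c")]);

    // Newly expressible with a single range method: bounded on both ends.
    let between: Vec<_> = index.range(1..=5).collect();
    assert_eq!(between, vec![(&1, &"block-a"), (&5, &"block-b")]);
}
```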

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*

n/a
codetheweb authored Nov 4, 2024
1 parent 154587c commit 8eae185
Showing 2 changed files with 100 additions and 202 deletions.
rust/blockstore/src/arrow/blockfile.rs (22 changes: 17 additions & 5 deletions)

@@ -483,7 +483,13 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
         key: K,
     ) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
         // Get all block ids that contain keys > key from sparse index for this prefix.
-        let block_ids = self.root.sparse_index.get_block_ids_gt(prefix, key.clone());
+        let block_ids = self.root.sparse_index.get_block_ids_range(
+            prefix..=prefix,
+            (
+                std::ops::Bound::Excluded(key.clone()),
+                std::ops::Bound::Unbounded,
+            ),
+        );
         let mut result: Vec<(K, V)> = vec![];
         // Read all the blocks individually to get keys > key.
         for block_id in block_ids {
@@ -515,7 +521,10 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
         key: K,
     ) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
         // Get all block ids that contain keys < key from sparse index.
-        let block_ids = self.root.sparse_index.get_block_ids_lt(prefix, key.clone());
+        let block_ids = self
+            .root
+            .sparse_index
+            .get_block_ids_range(prefix..=prefix, ..key.clone());
         let mut result: Vec<(K, V)> = vec![];
         // Read all the blocks individually to get keys < key.
         for block_id in block_ids {
@@ -550,7 +559,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
         let block_ids = self
             .root
             .sparse_index
-            .get_block_ids_gte(prefix, key.clone());
+            .get_block_ids_range(prefix..=prefix, key.clone()..);
         let mut result: Vec<(K, V)> = vec![];
         // Read all the blocks individually to get keys >= key.
         for block_id in block_ids {
@@ -585,7 +594,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
         let block_ids = self
             .root
             .sparse_index
-            .get_block_ids_lte(prefix, key.clone());
+            .get_block_ids_range(prefix..=prefix, ..=key.clone());
         let mut result: Vec<(K, V)> = vec![];
         // Read all the blocks individually to get keys <= key.
         for block_id in block_ids {
@@ -615,7 +624,10 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
         &'me self,
         prefix: &str,
     ) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
-        let block_ids = self.root.sparse_index.get_block_ids_prefix(prefix);
+        let block_ids = self
+            .root
+            .sparse_index
+            .get_block_ids_range::<K, _, _>(prefix..=prefix, ..);
         let mut result: Vec<(K, V)> = vec![];
         for block_id in block_ids {
             let block_opt = match self.get_block(block_id).await {
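
Condensed, the five call sites above translate the old specialized methods into range arguments as follows. A standalone sketch using `u32` as a stand-in key type (the real call sites pass the blockfile's generic key `K`):

```rust
use std::ops::{Bound, RangeBounds};

fn main() {
    // Stand-in key; illustrative only.
    let key = 5u32;

    // get_block_ids_gt(prefix, key)  -> key range (Excluded(key), Unbounded)
    let gt = (Bound::Excluded(key), Bound::Unbounded);
    assert!(matches!(gt.start_bound(), Bound::Excluded(&5)));

    // get_block_ids_lt(prefix, key)  -> key range ..key
    assert!(matches!((..key).end_bound(), Bound::Excluded(&5)));

    // get_block_ids_gte(prefix, key) -> key range key..
    assert!(matches!((key..).start_bound(), Bound::Included(&5)));

    // get_block_ids_lte(prefix, key) -> key range ..=key
    assert!(matches!((..=key).end_bound(), Bound::Included(&5)));

    // get_block_ids_prefix(prefix)   -> key range .. (unbounded on both sides)
    let full = ..;
    assert!(matches!(RangeBounds::<u32>::start_bound(&full), Bound::Unbounded));
}
```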
rust/blockstore/src/arrow/sparse_index.rs (280 changes: 83 additions & 197 deletions)

@@ -1,11 +1,11 @@
 use super::types::ArrowReadableKey;
-use crate::key::{CompositeKey, KeyWrapper};
+use crate::key::CompositeKey;
 use chroma_error::ChromaError;
 use core::panic;
 use parking_lot::Mutex;
 use serde::{Deserialize, Serialize};
 use std::collections::{BTreeMap, HashMap};
 use std::fmt::Debug;
+use std::ops::{Bound, RangeBounds};
 use std::sync::Arc;
 use thiserror::Error;
 use uuid::Uuid;
@@ -361,213 +361,99 @@ impl SparseIndexReader {
         result_uuids
     }
 
-    /// Get all block ids that have keys with the given prefix
-    pub(super) fn get_block_ids_prefix(&self, prefix: &str) -> Vec<Uuid> {
-        let data = &self.data;
-        let forward = &data.forward;
-        let curr_iter = forward.iter();
-        let mut next_iter = forward.iter().skip(1);
-        let mut block_ids = vec![];
-        for (curr_key, curr_block_value) in curr_iter {
-            let non_start_curr_key: Option<&CompositeKey> = match curr_key {
-                SparseIndexDelimiter::Start => None,
-                SparseIndexDelimiter::Key(k) => Some(k),
-            };
-            if let Some((next_key, _)) = next_iter.next() {
-                // This can't be a start key but we still need to extract it.
-                let non_start_next_key: Option<&CompositeKey> = match next_key {
-                    SparseIndexDelimiter::Start => {
-                        panic!("Invariant violation. Sparse index is not valid.");
-                    }
-                    SparseIndexDelimiter::Key(k) => Some(k),
-                };
-                // If delimeter starts with the same prefix then there will be keys inside the
-                // block with this prefix.
-                if non_start_curr_key.is_some()
-                    && prefix == non_start_curr_key.unwrap().prefix.as_str()
-                {
-                    block_ids.push(curr_block_value.id);
-                }
-                // If prefix is between the current delim and next delim then there could
-                // be keys in this block that have this prefix.
-                if (non_start_curr_key.is_none()
-                    || prefix > non_start_curr_key.unwrap().prefix.as_str())
-                    && (prefix <= non_start_next_key.unwrap().prefix.as_str())
-                {
-                    block_ids.push(curr_block_value.id);
-                }
-            } else {
-                // Last block.
-                if non_start_curr_key.is_none()
-                    || prefix >= non_start_curr_key.unwrap().prefix.as_str()
-                {
-                    block_ids.push(curr_block_value.id);
-                }
-            }
-        }
-        block_ids
-    }
-
-    /// Get all block ids that have keys with the given prefix and key greater than the given key
-    pub(super) fn get_block_ids_gt<'a, K: ArrowReadableKey<'a> + Into<KeyWrapper>>(
-        &self,
-        prefix: &str,
-        key: K,
-    ) -> Vec<Uuid> {
-        let data = &self.data;
-        let forward = &data.forward;
-        let curr_iter = forward.iter();
-        let mut next_iter = forward.iter().skip(1);
-        let mut block_ids = vec![];
-        for (curr_delim, curr_block_value) in curr_iter {
-            let curr_key = match curr_delim {
-                SparseIndexDelimiter::Start => None,
-                SparseIndexDelimiter::Key(k) => Some(k),
-            };
-            let mut next_key: Option<&CompositeKey> = None;
-            if let Some((next_delim, _)) = next_iter.next() {
-                next_key = match next_delim {
-                    SparseIndexDelimiter::Start => {
-                        panic!("Invariant violation. Sparse index is not valid.")
-                    }
-                    SparseIndexDelimiter::Key(k) => Some(k),
-                };
-            }
-            if (curr_key.is_none() || curr_key.unwrap().prefix.as_str() < prefix)
-                && (next_key.is_none() || next_key.unwrap().prefix.as_str() >= prefix)
-            {
-                block_ids.push(curr_block_value.id);
-            }
-            if let Some(curr_key) = curr_key {
-                if (curr_key.key > key.clone().into())
-                    || next_key.is_none()
-                    || next_key.unwrap().key > key.clone().into()
-                {
-                    block_ids.push(curr_block_value.id);
-                }
-            }
-        }
-        block_ids
-    }
-
-    /// Get all block ids that have keys with the given prefix and key less than the given key
-    pub(super) fn get_block_ids_lt<'a, K: ArrowReadableKey<'a> + Into<KeyWrapper>>(
-        &self,
-        prefix: &str,
-        key: K,
-    ) -> Vec<Uuid> {
-        let data = &self.data;
-        let forward = &data.forward;
-        let curr_iter = forward.iter();
-        let mut next_iter = forward.iter().skip(1);
-        let mut block_ids = vec![];
-        for (curr_delim, curr_block_value) in curr_iter {
-            let curr_key = match curr_delim {
-                SparseIndexDelimiter::Start => None,
-                SparseIndexDelimiter::Key(k) => Some(k),
-            };
-            let mut next_key: Option<&CompositeKey> = None;
-            if let Some((next_delim, _)) = next_iter.next() {
-                next_key = match next_delim {
-                    SparseIndexDelimiter::Start => {
-                        panic!("Invariant violation. Sparse index is not valid.")
-                    }
-                    SparseIndexDelimiter::Key(k) => Some(k),
-                };
-            }
-            if (curr_key.is_none() || curr_key.unwrap().prefix.as_str() < prefix)
-                && (next_key.is_none() || next_key.unwrap().prefix.as_str() >= prefix)
-            {
-                block_ids.push(curr_block_value.id);
-            }
-            if let Some(curr_key) = curr_key {
-                if curr_key.prefix.as_str() == prefix && curr_key.key < key.clone().into() {
-                    block_ids.push(curr_block_value.id);
-                }
-            }
-        }
-        block_ids
-    }
-
-    /// Get all block ids that have keys with the given prefix and key greater than or equal to the given key
-    pub(super) fn get_block_ids_gte<'a, K: ArrowReadableKey<'a> + Into<KeyWrapper>>(
-        &self,
-        prefix: &str,
-        key: K,
-    ) -> Vec<Uuid> {
-        let data = &self.data;
-        let forward = &data.forward;
-        let curr_iter = forward.iter();
-        let mut next_iter = forward.iter().skip(1);
-        let mut block_ids = vec![];
-        for (curr_delim, curr_block_value) in curr_iter {
-            let curr_key = match curr_delim {
-                SparseIndexDelimiter::Start => None,
-                SparseIndexDelimiter::Key(k) => Some(k),
-            };
-            let mut next_key: Option<&CompositeKey> = None;
-            if let Some((next_delim, _)) = next_iter.next() {
-                next_key = match next_delim {
-                    SparseIndexDelimiter::Start => {
-                        panic!("Invariant violation. Sparse index is not valid.")
-                    }
-                    SparseIndexDelimiter::Key(k) => Some(k),
-                };
-            }
-            if (curr_key.is_none() || curr_key.unwrap().prefix.as_str() < prefix)
-                && (next_key.is_none() || next_key.unwrap().prefix.as_str() >= prefix)
-            {
-                block_ids.push(curr_block_value.id);
-            }
-            if let Some(curr_key) = curr_key {
-                if curr_key.key >= key.clone().into()
-                    || next_key.is_none()
-                    || next_key.unwrap().key >= key.clone().into()
-                {
-                    block_ids.push(curr_block_value.id);
-                }
-            }
-        }
-        block_ids
-    }
-
-    /// Get all block ids that have keys with the given prefix and key less than or equal to the given key
-    pub(super) fn get_block_ids_lte<'a, K: ArrowReadableKey<'a> + Into<KeyWrapper>>(
-        &self,
-        prefix: &str,
-        key: K,
-    ) -> Vec<Uuid> {
-        let data = &self.data;
-        let forward = &data.forward;
-        let curr_iter = forward.iter();
-        let mut next_iter = forward.iter().skip(1);
-        let mut block_ids = vec![];
-        for (curr_delim, curr_block_value) in curr_iter {
-            let curr_key = match curr_delim {
-                SparseIndexDelimiter::Start => None,
-                SparseIndexDelimiter::Key(k) => Some(k),
-            };
-            let mut next_key: Option<&CompositeKey> = None;
-            if let Some((next_delim, _)) = next_iter.next() {
-                next_key = match next_delim {
-                    SparseIndexDelimiter::Start => {
-                        panic!("Invariant violation. Sparse index is not valid.")
-                    }
-                    SparseIndexDelimiter::Key(k) => Some(k),
-                };
-            }
-            if (curr_key.is_none() || curr_key.unwrap().prefix.as_str() < prefix)
-                && (next_key.is_none() || next_key.unwrap().prefix.as_str() >= prefix)
-            {
-                block_ids.push(curr_block_value.id);
-            }
-            if let Some(curr_key) = curr_key {
-                if curr_key.prefix.as_str() == prefix && curr_key.key <= key.clone().into() {
-                    block_ids.push(curr_block_value.id);
-                }
-            }
-        }
-        block_ids
-    }
+    pub(super) fn get_block_ids_range<'prefix, 'referred_data, K, PrefixRange, KeyRange>(
+        &self,
+        // These key ranges are flattened instead of using a single RangeBounds<CompositeKey> because not all keys have a well-defined min and max value. E.x. if the key is a string, there would be no way to get the range for all keys within a specific prefix.
+        prefix_range: PrefixRange,
+        key_range: KeyRange,
+    ) -> Vec<Uuid>
+    where
+        K: ArrowReadableKey<'referred_data>,
+        PrefixRange: RangeBounds<&'prefix str>,
+        KeyRange: RangeBounds<K>,
+    {
+        let forward = &self.data.forward;
+
+        // We do not materialize the last key of each block, so we must check the next block's start key to determine if the current block's end key is within the query range.
+        let start_keys_offset_by_1_iter = forward
+            .iter()
+            .skip(1)
+            .map(|(k, _)| match k {
+                SparseIndexDelimiter::Start => {
+                    panic!("Invariant violation. Sparse index is not valid.");
+                }
+                SparseIndexDelimiter::Key(k) => Some(k),
+            })
+            .chain(std::iter::once(None));
+
+        forward
+            .iter()
+            .zip(start_keys_offset_by_1_iter)
+            .map(|((start_key, block_uuid), end_key)| (block_uuid, start_key, end_key))
+            .filter(|(_, block_start_key, block_end_key)| {
+                let prefix_start_valid = match block_start_key {
+                    SparseIndexDelimiter::Start => true,
+                    SparseIndexDelimiter::Key(start_key) => match prefix_range.start_bound() {
+                        Bound::Included(prefix_start) => *prefix_start >= start_key.prefix.as_str(),
+                        Bound::Excluded(prefix_start) => *prefix_start > start_key.prefix.as_str(),
+                        Bound::Unbounded => true,
+                    },
+                };
+
+                if !prefix_start_valid {
+                    return false;
+                }
+
+                let prefix_end_valid = match prefix_range.end_bound() {
+                    Bound::Included(prefix_end) => match block_end_key {
+                        Some(end_key) => *prefix_end <= end_key.prefix.as_str(),
+                        None => true,
+                    },
+                    Bound::Excluded(prefix_end) => match block_end_key {
+                        Some(end_key) => *prefix_end < end_key.prefix.as_str(),
+                        None => true,
+                    },
+                    Bound::Unbounded => true,
+                };
+
+                if !prefix_end_valid {
+                    return false;
+                }
+
+                let key_start_valid = match block_end_key {
+                    Some(block_end_key) => match key_range.start_bound() {
+                        Bound::Included(key_range_start) => {
+                            key_range_start.clone().into() <= block_end_key.key
+                        }
+                        Bound::Excluded(key_range_start) => {
+                            key_range_start.clone().into() < block_end_key.key
+                        }
+                        Bound::Unbounded => true,
+                    },
+                    None => true,
+                };
+
+                if !key_start_valid {
+                    return false;
+                }
+
+                let key_end_valid = match block_start_key {
+                    SparseIndexDelimiter::Start => true,
+                    SparseIndexDelimiter::Key(start_key) => match key_range.end_bound() {
+                        Bound::Included(key_range_end) => {
+                            key_range_end.clone().into() >= start_key.key
+                        }
+                        Bound::Excluded(key_range_end) => {
+                            key_range_end.clone().into() > start_key.key
+                        }
+                        Bound::Unbounded => true,
+                    },
+                };
+
+                key_end_valid
+            })
+            .map(|(sparse_index_value, _, _)| sparse_index_value.id)
+            .collect()
+    }
 
     /// Fork the sparse index to create a new sparse index
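
The filter in the new reader works because each block's span is implied: a block runs from its own start key up to (but not including) the next block's start key, and the final block is unbounded. The sketch below demonstrates the same technique under simplified, made-up assumptions (integer keys, string block ids, no prefix dimension; none of these are the crate's real types):

```rust
use std::ops::{Bound, RangeBounds};

/// Simplified stand-in for the sparse index: `blocks` is a list of
/// (start_key, block_id) entries sorted by start_key. A block spans
/// [its start key, the next block's start key).
fn block_ids_in_range<R: RangeBounds<u32>>(
    blocks: &[(u32, &'static str)],
    key_range: R,
) -> Vec<&'static str> {
    // Pair each block with its successor's start key; the last block has no
    // successor, so its end is treated as unbounded (None).
    let end_keys = blocks
        .iter()
        .skip(1)
        .map(|(start, _)| Some(*start))
        .chain(std::iter::once(None));

    blocks
        .iter()
        .zip(end_keys)
        .filter(|((block_start, _), block_end)| {
            // A block overlaps the query range iff its end is past the range
            // start AND its start is at or before the range end. For an
            // Excluded start bound this may over-select one boundary block,
            // which is safe: callers read candidate blocks and filter exact
            // keys afterwards, so only under-selection would lose results.
            let start_ok = match (block_end, key_range.start_bound()) {
                (None, _) | (_, Bound::Unbounded) => true,
                (Some(end), Bound::Included(s)) => s < end,
                (Some(end), Bound::Excluded(s)) => s < end,
            };
            let end_ok = match key_range.end_bound() {
                Bound::Unbounded => true,
                Bound::Included(e) => e >= block_start,
                Bound::Excluded(e) => e > block_start,
            };
            start_ok && end_ok
        })
        .map(|((_, id), _)| *id)
        .collect()
}

fn main() {
    let blocks = [(0, "b0"), (10, "b1"), (20, "b2")];
    // Keys 5..15 span the end of b0 and the start of b1.
    assert_eq!(block_ids_in_range(&blocks, 5..15), vec!["b0", "b1"]);
    // An unbounded query returns every block.
    assert_eq!(block_ids_in_range(&blocks, ..), vec!["b0", "b1", "b2"]);
}
```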
