Skip to content

Commit

Permalink
[ENH]: replace get_* methods on memory blockfile impl with `get_ran…
Browse files Browse the repository at this point in the history
…ge()` (#2935)

## Description of changes

Replaces specialized methods like get_gt and get_lt with a single get_range() method that behaves similarly to the std BTreeMap::range() method. This reduces complexity/repetition and also enables queries that are bounded in both directions.

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
codetheweb authored Nov 4, 2024
1 parent 3cd54de commit ba09fa4
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 574 deletions.
163 changes: 62 additions & 101 deletions rust/blockstore/src/memory/reader_writer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::RangeBounds;

use super::{
super::{BlockfileError, Key, Value},
storage::{Readable, Storage, StorageBuilder, StorageManager, Writeable},
Expand Down Expand Up @@ -95,87 +97,27 @@ impl<
}
}

#[allow(clippy::type_complexity)]
pub(crate) fn get_by_prefix(
&'storage self,
prefix: &str,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
let values = V::get_by_prefix_from_storage(prefix, &self.storage);
if values.is_empty() {
return Err(Box::new(BlockfileError::NotFoundError));
}
let values = values
.iter()
.map(|(key, value)| (K::try_from(&key.key).unwrap(), value.clone()))
.collect();
Ok(values)
}

#[allow(clippy::type_complexity)]
pub(crate) fn get_gt(
&'storage self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
let key = key.into();
let values = V::read_gt_from_storage(prefix, key, &self.storage);
if values.is_empty() {
return Err(Box::new(BlockfileError::NotFoundError));
}
let values = values
.iter()
.map(|(key, value)| (K::try_from(&key.key).unwrap(), value.clone()))
.collect();
Ok(values)
}

#[allow(clippy::type_complexity)]
pub(crate) fn get_lt(
&'storage self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
let key = key.into();
let values = V::read_lt_from_storage(prefix, key, &self.storage);
if values.is_empty() {
return Err(Box::new(BlockfileError::NotFoundError));
}
let values = values
.iter()
.map(|(key, value)| (K::try_from(&key.key).unwrap(), value.clone()))
.collect();
Ok(values)
}

#[allow(clippy::type_complexity)]
pub(crate) fn get_gte(
pub(crate) fn get_range<'prefix, PrefixRange, KeyRange>(
&'storage self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
let key = key.into();
let values = V::read_gte_from_storage(prefix, key, &self.storage);
prefix_range: PrefixRange,
key_range: KeyRange,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>>
where
PrefixRange: RangeBounds<&'prefix str>,
KeyRange: RangeBounds<K>,
{
let values = V::read_range_from_storage(
prefix_range,
(
key_range.start_bound().map(|k| k.clone().into()),
key_range.end_bound().map(|k| k.clone().into()),
),
&self.storage,
);
if values.is_empty() {
return Err(Box::new(BlockfileError::NotFoundError));
}
let values = values
.iter()
.map(|(key, value)| (K::try_from(&key.key).unwrap(), value.clone()))
.collect();
Ok(values)
}

#[allow(clippy::type_complexity)]
pub(crate) fn get_lte(
&'storage self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
let key = key.into();
let values = V::read_lte_from_storage(prefix, key, &self.storage);
if values.is_empty() {
return Err(Box::new(BlockfileError::NotFoundError));
}
let values = values
.iter()
.map(|(key, value)| (K::try_from(&key.key).unwrap(), value.clone()))
Expand Down Expand Up @@ -210,6 +152,8 @@ impl<

#[cfg(test)]
mod tests {
use std::ops::Bound;

use super::*;
use chroma_types::{Chunk, DataRecord, LogRecord, Operation, OperationRecord};

Expand Down Expand Up @@ -369,7 +313,7 @@ mod tests {

let reader: MemoryBlockfileReader<&str, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_by_prefix("prefix").unwrap();
let values = reader.get_range("prefix"..="prefix", ..).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -390,7 +334,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gt("prefix", 3);
let values = reader.get_range("prefix"..="prefix", (Bound::Excluded(3), Bound::Unbounded));
assert!(values.is_err());
}

Expand All @@ -405,7 +349,9 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gt("prefix", 0).unwrap();
let values = reader
.get_range("prefix"..="prefix", (Bound::Excluded(0), Bound::Unbounded))
.unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -429,7 +375,9 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gt("prefix", 1).unwrap();
let values = reader
.get_range("prefix"..="prefix", (Bound::Excluded(1), Bound::Unbounded))
.unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -450,7 +398,10 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gt("prefix", 3.0);
let values = reader.get_range(
"prefix"..="prefix",
(Bound::Excluded(3.0), Bound::Unbounded),
);
assert!(values.is_err());
}

Expand All @@ -465,7 +416,12 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gt("prefix", 0.0).unwrap();
let values = reader
.get_range(
"prefix"..="prefix",
(Bound::Excluded(0.0), Bound::Unbounded),
)
.unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -489,7 +445,12 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gt("prefix", 1.0).unwrap();
let values = reader
.get_range(
"prefix"..="prefix",
(Bound::Excluded(1.0), Bound::Unbounded),
)
.unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -510,7 +471,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gte("prefix", 4);
let values = reader.get_range("prefix"..="prefix", 4..);
assert!(values.is_err());
}

Expand All @@ -525,7 +486,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gte("prefix", 1).unwrap();
let values = reader.get_range("prefix"..="prefix", 1..).unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -549,7 +510,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gte("prefix", 2).unwrap();
let values = reader.get_range("prefix"..="prefix", 2..).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -570,7 +531,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gte("prefix", 3.5);
let values = reader.get_range("prefix"..="prefix", 3.5..);
assert!(values.is_err());
}

Expand All @@ -585,7 +546,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gte("prefix", 0.5).unwrap();
let values = reader.get_range("prefix"..="prefix", 0.5..).unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -609,7 +570,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_gte("prefix", 1.5).unwrap();
let values = reader.get_range("prefix"..="prefix", 1.5..).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -630,7 +591,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lt("prefix", 1);
let values = reader.get_range("prefix"..="prefix", ..1);
assert!(values.is_err());
}

Expand All @@ -645,7 +606,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lt("prefix", 4).unwrap();
let values = reader.get_range("prefix"..="prefix", ..4).unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -669,7 +630,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lt("prefix", 3).unwrap();
let values = reader.get_range("prefix"..="prefix", ..3).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -690,7 +651,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lt("prefix", 0.5);
let values = reader.get_range("prefix"..="prefix", ..0.5);
assert!(values.is_err());
}

Expand All @@ -705,7 +666,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lt("prefix", 3.5).unwrap();
let values = reader.get_range("prefix"..="prefix", ..3.5).unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -729,7 +690,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lt("prefix", 2.5).unwrap();
let values = reader.get_range("prefix"..="prefix", ..2.5).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -750,7 +711,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lte("prefix", 0);
let values = reader.get_range("prefix"..="prefix", ..=0);
assert!(values.is_err());
}

Expand All @@ -765,7 +726,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lte("prefix", 3).unwrap();
let values = reader.get_range("prefix"..="prefix", ..=3).unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -789,7 +750,7 @@ mod tests {

let reader: MemoryBlockfileReader<u32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lte("prefix", 2).unwrap();
let values = reader.get_range("prefix"..="prefix", ..=2).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand All @@ -810,7 +771,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lte("prefix", 0.5);
let values = reader.get_range("prefix"..="prefix", ..=0.5);
assert!(values.is_err());
}

Expand All @@ -825,7 +786,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lte("prefix", 3.0).unwrap();
let values = reader.get_range("prefix"..="prefix", ..=3.0).unwrap();
assert_eq!(values.len(), 3);
assert!(values
.iter()
Expand All @@ -849,7 +810,7 @@ mod tests {

let reader: MemoryBlockfileReader<f32, &str> =
MemoryBlockfileReader::open(writer.id, storage_manager);
let values = reader.get_lte("prefix", 2.0).unwrap();
let values = reader.get_range("prefix"..="prefix", ..=2.0).unwrap();
assert_eq!(values.len(), 2);
assert!(values
.iter()
Expand Down
Loading

0 comments on commit ba09fa4

Please sign in to comment.