Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 75 additions & 74 deletions rust/lance-core/src/utils/mask.rs

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions rust/lance-core/src/utils/mask/nullable.rs
Comment thread
yanghua marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use deepsize::DeepSizeOf;

use super::{RowAddrTreeMap, RowIdMask};
use super::{RowAddrMask, RowAddrTreeMap};

/// A set of row ids, with optional set of nulls.
///
Expand Down Expand Up @@ -121,7 +121,7 @@ impl std::ops::BitOrAssign<&Self> for NullableRowAddrSet {
}
}

/// A version of [`RowIdMask`] that supports nulls.
/// A version of [`RowAddrMask`] that supports nulls.
///
/// This mask handles three-valued logic for SQL expressions, where a filter can
/// evaluate to TRUE, FALSE, or NULL. The `selected` set includes rows that are
Expand All @@ -144,13 +144,13 @@ impl NullableRowIdMask {
}
}

pub fn drop_nulls(self) -> RowIdMask {
pub fn drop_nulls(self) -> RowAddrMask {
match self {
Self::AllowList(NullableRowAddrSet { selected, nulls }) => {
RowIdMask::AllowList(selected - nulls)
RowAddrMask::AllowList(selected - nulls)
}
Self::BlockList(NullableRowAddrSet { selected, nulls }) => {
RowIdMask::BlockList(selected | nulls)
RowAddrMask::BlockList(selected | nulls)
}
}
}
Expand Down
14 changes: 7 additions & 7 deletions rust/lance-index/src/prefilter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
use std::sync::Arc;

use async_trait::async_trait;
use lance_core::utils::mask::RowIdMask;
use lance_core::utils::mask::RowAddrMask;
use lance_core::Result;

/// A trait to be implemented by anything supplying a prefilter row id mask
/// A trait to be implemented by anything supplying a prefilter row addr mask
///
/// This trait is for internal use only and has no stability guarantees.
#[async_trait]
pub trait FilterLoader: Send + 'static {
async fn load(self: Box<Self>) -> Result<RowIdMask>;
async fn load(self: Box<Self>) -> Result<RowAddrMask>;
}

/// Filter out row ids that we know are not relevant to the query.
Expand All @@ -36,10 +36,10 @@ pub trait PreFilter: Send + Sync {
/// If the filter is empty.
fn is_empty(&self) -> bool;

/// Get the row id mask for this prefilter
/// Get the row addr mask for this prefilter
///
/// This method must be called after `wait_for_ready`
fn mask(&self) -> Arc<RowIdMask>;
fn mask(&self) -> Arc<RowAddrMask>;

/// Check whether a slice of row ids should be included in a query.
///
Expand All @@ -63,8 +63,8 @@ impl PreFilter for NoFilter {
true
}

fn mask(&self) -> Arc<RowIdMask> {
Arc::new(RowIdMask::all_rows())
fn mask(&self) -> Arc<RowAddrMask> {
Arc::new(RowAddrMask::all_rows())
}

fn filter_row_ids<'a>(&self, row_ids: Box<dyn Iterator<Item = &'a u64> + 'a>) -> Vec<u64> {
Expand Down
18 changes: 9 additions & 9 deletions rust/lance-index/src/scalar/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use super::{
SearchResult, TextQuery, TokenQuery,
};
use lance_core::{
utils::mask::{NullableRowIdMask, RowIdMask},
utils::mask::{NullableRowIdMask, RowAddrMask},
Error, Result,
};
use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
Expand Down Expand Up @@ -983,19 +983,19 @@ impl NullableIndexExprResult {
#[derive(Debug)]
pub enum IndexExprResult {
// The answer is exactly the rows in the allow list minus the rows in the block list
Exact(RowIdMask),
Exact(RowAddrMask),
// The answer is at most the rows in the allow list minus the rows in the block list
// Some of the rows in the allow list may not be in the result and will need to be filtered
// by a recheck. Every row in the block list is definitely not in the result.
AtMost(RowIdMask),
AtMost(RowAddrMask),
// The answer is at least the rows in the allow list minus the rows in the block list
// Some of the rows in the block list might be in the result. Every row in the allow list is
// definitely in the result.
AtLeast(RowIdMask),
AtLeast(RowAddrMask),
}

impl IndexExprResult {
pub fn row_id_mask(&self) -> &RowIdMask {
pub fn row_addr_mask(&self) -> &RowAddrMask {
match self {
Self::Exact(mask) => mask,
Self::AtMost(mask) => mask,
Expand All @@ -1011,7 +1011,7 @@ impl IndexExprResult {
}
}

pub fn from_parts(mask: RowIdMask, discriminant: u32) -> Result<Self> {
pub fn from_parts(mask: RowAddrMask, discriminant: u32) -> Result<Self> {
match discriminant {
0 => Ok(Self::Exact(mask)),
1 => Ok(Self::AtMost(mask)),
Expand All @@ -1028,8 +1028,8 @@ impl IndexExprResult {
&self,
fragments_covered_by_result: &RoaringBitmap,
) -> Result<RecordBatch> {
let row_id_mask = self.row_id_mask();
let row_id_mask_arr = row_id_mask.into_arrow()?;
let row_addr_mask = self.row_addr_mask();
let row_addr_mask_arr = row_addr_mask.into_arrow()?;
let discriminant = self.discriminant();
let discriminant_arr =
Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc<dyn Array>;
Expand All @@ -1043,7 +1043,7 @@ impl IndexExprResult {
Ok(RecordBatch::try_new(
INDEX_EXPR_RESULT_SCHEMA.clone(),
vec![
Arc::new(row_id_mask_arr),
Arc::new(row_addr_mask_arr),
Arc::new(discriminant_arr),
Arc::new(fragments_covered_arr),
],
Expand Down
9 changes: 3 additions & 6 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ use futures::{stream, FutureExt, StreamExt, TryStreamExt};
use itertools::Itertools;
use lance_arrow::{iter_str_array, RecordBatchExt};
use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
use lance_core::utils::mask::RowAddrTreeMap;
use lance_core::utils::{
mask::RowIdMask,
tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
};
use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
use lance_core::{
container::list::ExpLinkedList,
utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
Expand Down Expand Up @@ -789,7 +786,7 @@ impl InvertedPartition {
&self,
params: &FtsSearchParams,
operator: Operator,
mask: Arc<RowIdMask>,
mask: Arc<RowAddrMask>,
postings: Vec<PostingIterator>,
metrics: &dyn MetricsCollector,
) -> Result<Vec<DocCandidate>> {
Expand Down
10 changes: 5 additions & 5 deletions rust/lance-index/src/scalar/inverted/wand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use arrow_array::{Array, UInt32Array};
use arrow_schema::DataType;
use itertools::Itertools;
use lance_core::utils::address::RowAddress;
use lance_core::utils::mask::RowIdMask;
use lance_core::utils::mask::RowAddrMask;
use lance_core::Result;

use crate::metrics::MetricsCollector;
Expand Down Expand Up @@ -341,15 +341,15 @@ impl<'a, S: Scorer> Wand<'a, S> {
pub(crate) fn search(
&mut self,
params: &FtsSearchParams,
mask: Arc<RowIdMask>,
mask: Arc<RowAddrMask>,
metrics: &dyn MetricsCollector,
) -> Result<Vec<DocCandidate>> {
let limit = params.limit.unwrap_or(usize::MAX);
if limit == 0 {
return Ok(vec![]);
}

match (mask.max_len(), mask.iter_ids()) {
match (mask.max_len(), mask.iter_addrs()) {
(Some(num_rows_matched), Some(row_ids))
if num_rows_matched * 100
<= FLAT_SEARCH_PERCENT_THRESHOLD.deref() * self.docs.len() as u64 =>
Expand Down Expand Up @@ -930,7 +930,7 @@ mod tests {
let result = wand
.search(
&FtsSearchParams::default(),
Arc::new(RowIdMask::default()),
Arc::new(RowAddrMask::default()),
&NoOpMetricsCollector,
)
.unwrap();
Expand Down Expand Up @@ -972,7 +972,7 @@ mod tests {

let result = wand.search(
&FtsSearchParams::default(),
Arc::new(RowIdMask::default()),
Arc::new(RowAddrMask::default()),
&NoOpMetricsCollector,
);

Expand Down
18 changes: 9 additions & 9 deletions rust/lance-index/src/vector/flat/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ impl IvfSubIndex for FlatIndex {
}
}
false => {
let row_id_mask = prefilter.mask();
let row_addr_mask = prefilter.mask();
if is_range_query {
let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
for (id, &row_id) in row_ids.enumerate() {
if !row_id_mask.selected(row_id) {
for (id, &row_addr) in row_ids.enumerate() {
if !row_addr_mask.selected(row_addr) {
continue;
}
let dist = dist_calc.distance(id as u32).into();
Expand All @@ -141,24 +141,24 @@ impl IvfSubIndex for FlatIndex {
}

if res.len() < k {
res.push(OrderedNode::new(row_id, dist));
res.push(OrderedNode::new(row_addr, dist));
} else if res.peek().unwrap().dist > dist {
res.pop();
res.push(OrderedNode::new(row_id, dist));
res.push(OrderedNode::new(row_addr, dist));
}
}
} else {
for (id, &row_id) in row_ids.enumerate() {
if !row_id_mask.selected(row_id) {
for (id, &row_addr) in row_ids.enumerate() {
if !row_addr_mask.selected(row_addr) {
continue;
}

let dist = dist_calc.distance(id as u32).into();
if res.len() < k {
res.push(OrderedNode::new(row_id, dist));
res.push(OrderedNode::new(row_addr, dist));
} else if res.peek().unwrap().dist > dist {
res.pop();
res.push(OrderedNode::new(row_id, dist));
res.push(OrderedNode::new(row_addr, dist));
}
}
}
Expand Down
Loading