Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 78 additions & 77 deletions rust/lance-core/src/utils/mask.rs

Large diffs are not rendered by default.

32 changes: 16 additions & 16 deletions rust/lance-core/src/utils/mask/nullable.rs
Comment thread
yanghua marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use deepsize::DeepSizeOf;

use super::{RowAddrTreeMap, RowIdMask};
use super::{RowAddrMask, RowAddrTreeMap};

/// A set of row ids, with optional set of nulls.
///
Expand Down Expand Up @@ -121,18 +121,18 @@ impl std::ops::BitOrAssign<&Self> for NullableRowAddrSet {
}
}

/// A version of [`RowIdMask`] that supports nulls.
/// A version of [`RowAddrMask`] that supports nulls.
///
/// This mask handles three-valued logic for SQL expressions, where a filter can
/// evaluate to TRUE, FALSE, or NULL. The `selected` set includes rows that are
/// TRUE or NULL. The `nulls` set includes rows that are NULL.
#[derive(Clone, Debug)]
pub enum NullableRowIdMask {
pub enum NullableRowAddrMask {
AllowList(NullableRowAddrSet),
BlockList(NullableRowAddrSet),
}

impl NullableRowIdMask {
impl NullableRowAddrMask {
pub fn selected(&self, row_id: u64) -> bool {
match self {
Self::AllowList(NullableRowAddrSet { selected, nulls }) => {
Expand All @@ -144,19 +144,19 @@ impl NullableRowIdMask {
}
}

pub fn drop_nulls(self) -> RowIdMask {
pub fn drop_nulls(self) -> RowAddrMask {
match self {
Self::AllowList(NullableRowAddrSet { selected, nulls }) => {
RowIdMask::AllowList(selected - nulls)
RowAddrMask::AllowList(selected - nulls)
}
Self::BlockList(NullableRowAddrSet { selected, nulls }) => {
RowIdMask::BlockList(selected | nulls)
RowAddrMask::BlockList(selected | nulls)
}
}
}
}

impl std::ops::Not for NullableRowIdMask {
impl std::ops::Not for NullableRowAddrMask {
type Output = Self;

fn not(self) -> Self::Output {
Expand All @@ -167,7 +167,7 @@ impl std::ops::Not for NullableRowIdMask {
}
}

impl std::ops::BitAnd for NullableRowIdMask {
impl std::ops::BitAnd for NullableRowAddrMask {
type Output = Self;

fn bitand(self, rhs: Self) -> Self::Output {
Expand Down Expand Up @@ -214,7 +214,7 @@ impl std::ops::BitAnd for NullableRowIdMask {
}
}

impl std::ops::BitOr for NullableRowIdMask {
impl std::ops::BitOr for NullableRowAddrMask {
type Output = Self;

fn bitor(self, rhs: Self) -> Self::Output {
Expand Down Expand Up @@ -275,15 +275,15 @@ mod tests {
NullableRowAddrSet::new(rows(selected), rows(nulls))
}

fn allow(selected: &[u64], nulls: &[u64]) -> NullableRowIdMask {
NullableRowIdMask::AllowList(nullable_set(selected, nulls))
fn allow(selected: &[u64], nulls: &[u64]) -> NullableRowAddrMask {
NullableRowAddrMask::AllowList(nullable_set(selected, nulls))
}

fn block(selected: &[u64], nulls: &[u64]) -> NullableRowIdMask {
NullableRowIdMask::BlockList(nullable_set(selected, nulls))
fn block(selected: &[u64], nulls: &[u64]) -> NullableRowAddrMask {
NullableRowAddrMask::BlockList(nullable_set(selected, nulls))
}

fn assert_mask_selects(mask: &NullableRowIdMask, selected: &[u64], not_selected: &[u64]) {
fn assert_mask_selects(mask: &NullableRowAddrMask, selected: &[u64], not_selected: &[u64]) {
for &id in selected {
assert!(mask.selected(id), "Expected row {} to be selected", id);
}
Expand Down Expand Up @@ -520,7 +520,7 @@ mod tests {
let not_mask = !block_mask;

// NOT(BlockList) = AllowList
assert!(matches!(not_mask, NullableRowIdMask::AllowList(_)));
assert!(matches!(not_mask, NullableRowAddrMask::AllowList(_)));
}

#[test]
Expand Down
14 changes: 7 additions & 7 deletions rust/lance-index/src/prefilter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
use std::sync::Arc;

use async_trait::async_trait;
use lance_core::utils::mask::RowIdMask;
use lance_core::utils::mask::RowAddrMask;
use lance_core::Result;

/// A trait to be implemented by anything supplying a prefilter row id mask
/// A trait to be implemented by anything supplying a prefilter row addr mask
///
/// This trait is for internal use only and has no stability guarantees.
#[async_trait]
pub trait FilterLoader: Send + 'static {
async fn load(self: Box<Self>) -> Result<RowIdMask>;
async fn load(self: Box<Self>) -> Result<RowAddrMask>;
}

/// Filter out row ids that we know are not relevant to the query.
Expand All @@ -36,10 +36,10 @@ pub trait PreFilter: Send + Sync {
/// If the filter is empty.
fn is_empty(&self) -> bool;

/// Get the row id mask for this prefilter
/// Get the row addr mask for this prefilter
///
/// This method must be called after `wait_for_ready`
fn mask(&self) -> Arc<RowIdMask>;
fn mask(&self) -> Arc<RowAddrMask>;

/// Check whether a slice of row ids should be included in a query.
///
Expand All @@ -63,8 +63,8 @@ impl PreFilter for NoFilter {
true
}

fn mask(&self) -> Arc<RowIdMask> {
Arc::new(RowIdMask::all_rows())
fn mask(&self) -> Arc<RowAddrMask> {
Arc::new(RowAddrMask::all_rows())
}

fn filter_row_ids<'a>(&self, row_ids: Box<dyn Iterator<Item = &'a u64> + 'a>) -> Vec<u64> {
Expand Down
58 changes: 31 additions & 27 deletions rust/lance-index/src/scalar/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use super::{
SearchResult, TextQuery, TokenQuery,
};
use lance_core::{
utils::mask::{NullableRowIdMask, RowIdMask},
utils::mask::{NullableRowAddrMask, RowAddrMask},
Error, Result,
};
use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
Expand Down Expand Up @@ -907,17 +907,17 @@ pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {

#[derive(Debug)]
enum NullableIndexExprResult {
Exact(NullableRowIdMask),
AtMost(NullableRowIdMask),
AtLeast(NullableRowIdMask),
Exact(NullableRowAddrMask),
AtMost(NullableRowAddrMask),
AtLeast(NullableRowAddrMask),
}

impl From<SearchResult> for NullableIndexExprResult {
fn from(result: SearchResult) -> Self {
match result {
SearchResult::Exact(mask) => Self::Exact(NullableRowIdMask::AllowList(mask)),
SearchResult::AtMost(mask) => Self::AtMost(NullableRowIdMask::AllowList(mask)),
SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowIdMask::AllowList(mask)),
SearchResult::Exact(mask) => Self::Exact(NullableRowAddrMask::AllowList(mask)),
SearchResult::AtMost(mask) => Self::AtMost(NullableRowAddrMask::AllowList(mask)),
SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowAddrMask::AllowList(mask)),
}
}
}
Expand Down Expand Up @@ -983,19 +983,19 @@ impl NullableIndexExprResult {
#[derive(Debug)]
pub enum IndexExprResult {
// The answer is exactly the rows in the allow list minus the rows in the block list
Exact(RowIdMask),
Exact(RowAddrMask),
// The answer is at most the rows in the allow list minus the rows in the block list
// Some of the rows in the allow list may not be in the result and will need to be filtered
// by a recheck. Every row in the block list is definitely not in the result.
AtMost(RowIdMask),
AtMost(RowAddrMask),
// The answer is at least the rows in the allow list minus the rows in the block list
// Some of the rows in the block list might be in the result. Every row in the allow list is
// definitely in the result.
AtLeast(RowIdMask),
AtLeast(RowAddrMask),
}

impl IndexExprResult {
pub fn row_id_mask(&self) -> &RowIdMask {
pub fn row_addr_mask(&self) -> &RowAddrMask {
match self {
Self::Exact(mask) => mask,
Self::AtMost(mask) => mask,
Expand All @@ -1011,7 +1011,7 @@ impl IndexExprResult {
}
}

pub fn from_parts(mask: RowIdMask, discriminant: u32) -> Result<Self> {
pub fn from_parts(mask: RowAddrMask, discriminant: u32) -> Result<Self> {
match discriminant {
0 => Ok(Self::Exact(mask)),
1 => Ok(Self::AtMost(mask)),
Expand All @@ -1028,8 +1028,8 @@ impl IndexExprResult {
&self,
fragments_covered_by_result: &RoaringBitmap,
) -> Result<RecordBatch> {
let row_id_mask = self.row_id_mask();
let row_id_mask_arr = row_id_mask.into_arrow()?;
let row_addr_mask = self.row_addr_mask();
let row_addr_mask_arr = row_addr_mask.into_arrow()?;
let discriminant = self.discriminant();
let discriminant_arr =
Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc<dyn Array>;
Expand All @@ -1043,7 +1043,7 @@ impl IndexExprResult {
Ok(RecordBatch::try_new(
INDEX_EXPR_RESULT_SCHEMA.clone(),
vec![
Arc::new(row_id_mask_arr),
Arc::new(row_addr_mask_arr),
Arc::new(discriminant_arr),
Arc::new(fragments_covered_arr),
],
Expand Down Expand Up @@ -2213,7 +2213,7 @@ mod tests {
}

// AtMost: superset of matches (e.g., bloom filter says "might be in [1,2]")
let at_most = NullableIndexExprResult::AtMost(NullableRowIdMask::AllowList(
let at_most = NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList(
NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()),
));
// NOT(AtMost) should be AtLeast (definitely NOT in [1,2], might be elsewhere)
Expand All @@ -2223,7 +2223,7 @@ mod tests {
));

// AtLeast: subset of matches (e.g., definitely in [1,2], might be more)
let at_least = NullableIndexExprResult::AtLeast(NullableRowIdMask::AllowList(
let at_least = NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList(
NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()),
));
// NOT(AtLeast) should be AtMost (might NOT be in [1,2], definitely elsewhere)
Expand All @@ -2233,7 +2233,7 @@ mod tests {
));

// Exact should stay Exact
let exact = NullableIndexExprResult::Exact(NullableRowIdMask::AllowList(
let exact = NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList(
NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()),
));
assert!(matches!(
Expand All @@ -2248,21 +2248,25 @@ mod tests {

// Test that AND/OR correctly propagate certainty
let make_at_most = || {
NullableIndexExprResult::AtMost(NullableRowIdMask::AllowList(NullableRowAddrSet::new(
RowAddrTreeMap::from_iter(&[1, 2, 3]),
RowAddrTreeMap::new(),
)))
NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList(
NullableRowAddrSet::new(
RowAddrTreeMap::from_iter(&[1, 2, 3]),
RowAddrTreeMap::new(),
),
))
};

let make_at_least = || {
NullableIndexExprResult::AtLeast(NullableRowIdMask::AllowList(NullableRowAddrSet::new(
RowAddrTreeMap::from_iter(&[2, 3, 4]),
RowAddrTreeMap::new(),
)))
NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList(
NullableRowAddrSet::new(
RowAddrTreeMap::from_iter(&[2, 3, 4]),
RowAddrTreeMap::new(),
),
))
};

let make_exact = || {
NullableIndexExprResult::Exact(NullableRowIdMask::AllowList(NullableRowAddrSet::new(
NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList(NullableRowAddrSet::new(
RowAddrTreeMap::from_iter(&[1, 2]),
RowAddrTreeMap::new(),
)))
Expand Down
9 changes: 3 additions & 6 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ use futures::{stream, FutureExt, StreamExt, TryStreamExt};
use itertools::Itertools;
use lance_arrow::{iter_str_array, RecordBatchExt};
use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
use lance_core::utils::mask::RowAddrTreeMap;
use lance_core::utils::{
mask::RowIdMask,
tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
};
use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
use lance_core::{
container::list::ExpLinkedList,
utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
Expand Down Expand Up @@ -789,7 +786,7 @@ impl InvertedPartition {
&self,
params: &FtsSearchParams,
operator: Operator,
mask: Arc<RowIdMask>,
mask: Arc<RowAddrMask>,
postings: Vec<PostingIterator>,
metrics: &dyn MetricsCollector,
) -> Result<Vec<DocCandidate>> {
Expand Down
10 changes: 5 additions & 5 deletions rust/lance-index/src/scalar/inverted/wand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use arrow_array::{Array, UInt32Array};
use arrow_schema::DataType;
use itertools::Itertools;
use lance_core::utils::address::RowAddress;
use lance_core::utils::mask::RowIdMask;
use lance_core::utils::mask::RowAddrMask;
use lance_core::Result;

use crate::metrics::MetricsCollector;
Expand Down Expand Up @@ -341,15 +341,15 @@ impl<'a, S: Scorer> Wand<'a, S> {
pub(crate) fn search(
&mut self,
params: &FtsSearchParams,
mask: Arc<RowIdMask>,
mask: Arc<RowAddrMask>,
metrics: &dyn MetricsCollector,
) -> Result<Vec<DocCandidate>> {
let limit = params.limit.unwrap_or(usize::MAX);
if limit == 0 {
return Ok(vec![]);
}

match (mask.max_len(), mask.iter_ids()) {
match (mask.max_len(), mask.iter_addrs()) {
(Some(num_rows_matched), Some(row_ids))
if num_rows_matched * 100
<= FLAT_SEARCH_PERCENT_THRESHOLD.deref() * self.docs.len() as u64 =>
Expand Down Expand Up @@ -930,7 +930,7 @@ mod tests {
let result = wand
.search(
&FtsSearchParams::default(),
Arc::new(RowIdMask::default()),
Arc::new(RowAddrMask::default()),
&NoOpMetricsCollector,
)
.unwrap();
Expand Down Expand Up @@ -972,7 +972,7 @@ mod tests {

let result = wand.search(
&FtsSearchParams::default(),
Arc::new(RowIdMask::default()),
Arc::new(RowAddrMask::default()),
&NoOpMetricsCollector,
);

Expand Down
Loading
Loading