Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(JoinInner): use AVs for keys instead of PVs #1194

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 26 additions & 82 deletions crates/core/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::relation::RowCount;
use spacetimedb_primitives::*;
use spacetimedb_sats::db::def::TableDef;
use spacetimedb_sats::relation::{DbTable, FieldExpr, FieldExprRef, Header, Relation};
use spacetimedb_sats::relation::{DbTable, FieldExpr, Header, Relation};
use spacetimedb_sats::{AlgebraicValue, ProductValue};
use spacetimedb_vm::errors::ErrorVm;
use spacetimedb_vm::eval::IterRows;
use spacetimedb_vm::eval::{join_inner, IterRows};
use spacetimedb_vm::expr::*;
use spacetimedb_vm::iterators::RelIter;
use spacetimedb_vm::program::{ProgramVm, Sources};
Expand Down Expand Up @@ -201,12 +201,12 @@ pub fn build_query<'a>(
}
}
Query::JoinInner(join) => {
let result = result
let lhs = result
.take()
.map(Ok)
.unwrap_or_else(|| get_table(ctx, stdb, tx, &query.source, sources))?;
let iter = join_inner(ctx, stdb, tx, result, join, sources)?;
Box::new(iter)
let rhs = build_query(ctx, stdb, tx, &join.rhs, sources)?;
join_inner(lhs, rhs, join)?
}
})
}
Expand All @@ -216,53 +216,6 @@ pub fn build_query<'a>(
.unwrap_or_else(|| get_table(ctx, stdb, tx, &query.source, sources))
}

fn join_inner<'a>(
ctx: &'a ExecutionContext,
db: &'a RelationalDB,
tx: &'a TxMode<'a>,
lhs: impl RelOps<'a> + 'a,
rhs: &'a JoinExpr,
sources: &mut impl SourceProvider<'a>,
) -> Result<impl RelOps<'a> + 'a, ErrorVm> {
let semi = rhs.semi;

let col_lhs = FieldExprRef::Name(rhs.col_lhs);
let col_rhs = FieldExprRef::Name(rhs.col_rhs);
let key_lhs = [col_lhs];
let key_rhs = [col_rhs];

let rhs = build_query(ctx, db, tx, &rhs.rhs, sources)?;
let key_lhs_header = lhs.head().clone();
let key_rhs_header = rhs.head().clone();
let col_lhs_header = lhs.head().clone();
let col_rhs_header = rhs.head().clone();

let header = if semi {
col_lhs_header.clone()
} else {
Arc::new(col_lhs_header.extend(&col_rhs_header))
};

lhs.join_inner(
rhs,
header,
move |row| Ok(row.project(&key_lhs, &key_lhs_header)?),
move |row| Ok(row.project(&key_rhs, &key_rhs_header)?),
move |l, r| {
let l = l.get(col_lhs, &col_lhs_header)?;
let r = r.get(col_rhs, &col_rhs_header)?;
Ok(l == r)
},
move |l, r| {
if semi {
l
} else {
l.extend(r)
}
},
)
}

/// Resolve `query` to a table iterator,
/// either taken from an in-memory table, in the case of [`SourceExpr::InMemory`],
/// or from a physical table, in the case of [`SourceExpr::DbTable`].
Expand Down Expand Up @@ -678,12 +631,27 @@ pub(crate) mod tests {
Ok((schema, row))
}

fn run_query<const N: usize>(
db: &RelationalDB,
q: QueryExpr,
sources: SourceSet<Vec<ProductValue>, N>,
) -> MemTable {
let ctx = ExecutionContext::default();
db.with_read_only(&ctx, |tx| {
let mut tx_mode = (&*tx).into();
let p = &mut DbProgram::new(&ctx, db, &mut tx_mode, AuthCtx::for_testing());
match run_ast(p, q.into(), sources) {
Code::Table(x) => x,
x => panic!("invalid result {x}"),
}
})
}

#[test]
fn test_db_query_inner_join() -> ResultTest<()> {
let stdb = TestDB::durable()?;

let ctx = ExecutionContext::default();
let (schema, _) = stdb.with_auto_commit(&ctx, |tx| create_inv_table(&stdb, tx))?;
let (schema, _) = stdb.with_auto_commit(&ExecutionContext::default(), |tx| create_inv_table(&stdb, tx))?;
let table_id = schema.table_id;

let data = mem_table_one_u64(u32::MAX.into());
Expand All @@ -692,15 +660,7 @@ pub(crate) mod tests {
let rhs_source_expr = sources.add_mem_table(data);
let q =
QueryExpr::new(&*schema).with_join_inner(rhs_source_expr, FieldName::new(table_id, 0.into()), rhs, false);

let result = stdb.with_read_only(&ctx, |tx| {
let mut tx_mode = (&*tx).into();
let p = &mut DbProgram::new(&ctx, &stdb, &mut tx_mode, AuthCtx::for_testing());
match run_ast(p, q.into(), sources) {
Code::Table(x) => x,
x => panic!("invalid result {x}"),
}
});
let result = run_query(&stdb, q, sources);

// The expected result.
let inv = ProductType::from([AlgebraicType::U64, AlgebraicType::String, AlgebraicType::U64]);
Expand All @@ -724,18 +684,9 @@ pub(crate) mod tests {
let rhs = *data.get_field_pos(0).unwrap();
let mut sources = SourceSet::<_, 1>::empty();
let rhs_source_expr = sources.add_mem_table(data);

let q =
QueryExpr::new(&*schema).with_join_inner(rhs_source_expr, FieldName::new(table_id, 0.into()), rhs, true);

let result = stdb.with_read_only(&ctx, |tx| {
let mut tx_mode = (&*tx).into();
let p = &mut DbProgram::new(&ctx, &stdb, &mut tx_mode, AuthCtx::for_testing());
match run_ast(p, q.into(), sources) {
Code::Table(x) => x,
x => panic!("invalid result {x}"),
}
});
let result = run_query(&stdb, q, sources);

// The expected result.
let input = mem_table(schema.table_id, schema.get_row_type().clone(), vec![row]);
Expand All @@ -745,16 +696,9 @@ pub(crate) mod tests {
}

fn check_catalog(db: &RelationalDB, name: &str, row: ProductValue, q: QueryExpr, schema: &TableSchema) {
let ctx = ExecutionContext::default();
let result = db.with_read_only(&ctx, |tx| {
let tx_mode = &mut (&*tx).into();
let p = &mut DbProgram::new(&ctx, db, tx_mode, AuthCtx::for_testing());
run_ast(p, q.into(), [].into())
});

// The expected result.
let result = run_query(db, q, [].into());
let input = MemTable::from_iter(Header::from(schema).into(), [row]);
assert_eq!(result, Code::Table(input), "{}", name);
assert_eq!(result, input, "{}", name);
}

#[test]
Expand Down
68 changes: 26 additions & 42 deletions crates/vm/src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::errors::ErrorVm;
use crate::expr::{Code, SourceExpr, SourceSet};
use crate::expr::{Code, JoinExpr, SourceExpr, SourceSet};
use crate::expr::{Expr, Query};
use crate::iterators::RelIter;
use crate::program::{ProgramVm, Sources};
use crate::rel_ops::RelOps;
use crate::relation::RelValue;
use spacetimedb_sats::relation::{FieldExprRef, Relation};
use spacetimedb_sats::relation::Relation;
use spacetimedb_sats::ProductValue;
use std::sync::Arc;

Expand Down Expand Up @@ -46,54 +46,38 @@ pub fn build_query<'a, const N: usize>(
}
}
Query::JoinInner(q) => {
// Pick the smaller set to be at the left.
let col_lhs = FieldExprRef::Name(q.col_lhs);
let col_rhs = FieldExprRef::Name(q.col_rhs);

let rhs = build_source_expr_query(sources, &q.rhs.source);
let rhs = build_query(rhs, &q.rhs.query, sources)?;

let lhs = result;
let key_lhs_header = lhs.head().clone();
let key_rhs_header = rhs.head().clone();
let col_lhs_header = lhs.head().clone();
let col_rhs_header = rhs.head().clone();

if q.semi {
let iter = lhs.join_inner(
rhs,
col_lhs_header.clone(),
move |row| Ok(row.get(col_lhs, &key_lhs_header)?.into_owned().into()),
move |row| Ok(row.get(col_rhs, &key_rhs_header)?.into_owned().into()),
move |l, r| {
let l = l.get(col_lhs, &col_lhs_header)?;
let r = r.get(col_rhs, &col_rhs_header)?;
Ok(l == r)
},
|l, _| l,
)?;
Box::new(iter)
} else {
let iter = lhs.join_inner(
rhs,
Arc::new(col_lhs_header.extend(&col_rhs_header)),
move |row| Ok(row.get(col_lhs, &key_lhs_header)?.into_owned().into()),
move |row| Ok(row.get(col_rhs, &key_rhs_header)?.into_owned().into()),
move |l, r| {
let l = l.get(col_lhs, &col_lhs_header)?;
let r = r.get(col_rhs, &col_rhs_header)?;
Ok(l == r)
},
move |l, r| l.extend(r),
)?;
Box::new(iter)
}
join_inner(result, rhs, q)?
}
};
}
Ok(result)
}

pub fn join_inner<'a>(
lhs: impl RelOps<'a> + 'a,
rhs: impl RelOps<'a> + 'a,
q: &'a JoinExpr,
) -> Result<Box<IterRows<'a>>, ErrorVm> {
let lhs_head = lhs.head();
let rhs_head = rhs.head();
let col_lhs = lhs_head.column_pos_or_err(q.col_lhs)?;
let col_rhs = rhs_head.column_pos_or_err(q.col_rhs)?;

let key_lhs = move |row: &RelValue<'_>| row.read_column(col_lhs.idx()).unwrap().into_owned();
let key_rhs = move |row: &RelValue<'_>| row.read_column(col_rhs.idx()).unwrap().into_owned();
let pred = move |l: &RelValue<'_>, r: &RelValue<'_>| l.read_column(col_lhs.idx()) == r.read_column(col_rhs.idx());

Ok(if q.semi {
let head = lhs_head.clone();
Box::new(lhs.join_inner(rhs, head, key_lhs, key_rhs, pred, move |l, _| l)?)
} else {
let head = Arc::new(lhs_head.extend(rhs_head));
Box::new(lhs.join_inner(rhs, head, key_lhs, key_rhs, pred, move |l, r| l.extend(r))?)
})
}

pub(crate) fn build_source_expr_query<'a, const N: usize>(
sources: Sources<'_, N>,
source: &SourceExpr,
Expand Down
38 changes: 17 additions & 21 deletions crates/vm/src/rel_ops.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::errors::ErrorVm;
use crate::relation::RelValue;
use spacetimedb_data_structures::map::HashMap;
use spacetimedb_sats::product_value::ProductValue;
use spacetimedb_sats::relation::{FieldExpr, Header, RowCount};
use spacetimedb_sats::AlgebraicValue;
use std::sync::Arc;

/// A trait for dealing with fallible iterators for the database.
Expand Down Expand Up @@ -89,10 +89,10 @@ pub trait RelOps<'a> {
) -> Result<JoinInner<'a, Self, Rhs, KeyLhs, KeyRhs, Pred, Proj>, ErrorVm>
where
Self: Sized,
Pred: FnMut(&RelValue<'a>, &RelValue<'a>) -> Result<bool, ErrorVm>,
Pred: FnMut(&RelValue<'a>, &RelValue<'a>) -> bool,
Proj: FnMut(RelValue<'a>, RelValue<'a>) -> RelValue<'a>,
KeyLhs: FnMut(&RelValue<'a>) -> Result<ProductValue, ErrorVm>,
KeyRhs: FnMut(&RelValue<'a>) -> Result<ProductValue, ErrorVm>,
KeyLhs: FnMut(&RelValue<'a>) -> AlgebraicValue,
KeyRhs: FnMut(&RelValue<'a>) -> AlgebraicValue,
Rhs: RelOps<'a>,
{
Ok(JoinInner::new(head, self, with, key_lhs, key_rhs, predicate, project))
Expand Down Expand Up @@ -238,7 +238,7 @@ pub struct JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> {
pub(crate) key_rhs: KeyRhs,
pub(crate) predicate: Pred,
pub(crate) projection: Proj,
map: HashMap<ProductValue, Vec<RelValue<'a>>>,
map: HashMap<AlgebraicValue, Vec<RelValue<'a>>>,
filled_rhs: bool,
left: Option<RelValue<'a>>,
}
Expand Down Expand Up @@ -273,9 +273,9 @@ where
Lhs: RelOps<'a>,
Rhs: RelOps<'a>,
// TODO(Centril): consider using keys that aren't `ProductValue`s.
Centril marked this conversation as resolved.
Show resolved Hide resolved
KeyLhs: FnMut(&RelValue<'a>) -> Result<ProductValue, ErrorVm>,
KeyRhs: FnMut(&RelValue<'a>) -> Result<ProductValue, ErrorVm>,
Pred: FnMut(&RelValue<'a>, &RelValue<'a>) -> Result<bool, ErrorVm>,
KeyLhs: FnMut(&RelValue<'a>) -> AlgebraicValue,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can't return &AlgebraicValue because it might be freshly allocated, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, it could at most return Cow<'_, AV> and we might want to attempt that in a follow up and see if it is beneficial.

KeyRhs: FnMut(&RelValue<'a>) -> AlgebraicValue,
Pred: FnMut(&RelValue<'a>, &RelValue<'a>) -> bool,
Proj: FnMut(RelValue<'a>, RelValue<'a>) -> RelValue<'a>,
{
fn head(&self) -> &Arc<Header> {
Expand All @@ -287,33 +287,29 @@ where
if !self.filled_rhs {
self.map = HashMap::with_capacity(self.rhs.row_count().min);
while let Some(row_rhs) = self.rhs.next()? {
let key_rhs = (self.key_rhs)(&row_rhs)?;
let key_rhs = (self.key_rhs)(&row_rhs);
self.map.entry(key_rhs).or_default().push(row_rhs);
}
self.filled_rhs = true;
}

loop {
// Consume a row in `Lhs` and project to `KeyLhs`.
let lhs = if let Some(left) = &self.left {
left.clone()
} else {
match self.lhs.next()? {
let lhs = match &self.left {
Some(left) => left,
None => match self.lhs.next()? {
Some(x) => self.left.insert(x),
None => return Ok(None),
Some(x) => {
self.left = Some(x.clone());
x
}
}
},
};
let k = (self.key_lhs)(&lhs)?;
let k = (self.key_lhs)(lhs);

// If we can relate `KeyLhs` and `KeyRhs`, we have candidate.
// If that candidate still has rhs elements, test against the predicate and yield.
if let Some(rvv) = self.map.get_mut(&k) {
if let Some(rhs) = rvv.pop() {
if (self.predicate)(&lhs, &rhs)? {
return Ok(Some((self.projection)(lhs, rhs)));
if (self.predicate)(lhs, &rhs) {
return Ok(Some((self.projection)(lhs.clone(), rhs)));
}
}
}
Expand Down
Loading