Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions java/core/lance-jni/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,67 +77,6 @@ fn inner_into_batch_records(
Ok(())
}

// JNI entry point for `com.lancedb.lance.SqlQuery#intoExplainPlan`.
//
// Builds the SQL query against the given dataset and returns the rendered
// (optionally verbose/analyzed) execution plan to Java as a `String`.
// All work is delegated to `inner_into_explain_plan`; any error is mapped
// to an I/O error and — presumably via `ok_or_throw_with_return!` — thrown
// as a Java exception, with a default (null) JString returned to the JVM.
// NOTE(review): macro semantics inferred from its name; confirm against
// the macro definition elsewhere in this crate.
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_SqlQuery_intoExplainPlan<'local>(
    mut env: JNIEnv<'local>,
    _class: JClass,
    java_dataset: JObject,
    sql: JString,
    table_name: JObject,
    with_row_id: jboolean,
    with_row_addr: jboolean,
    verbose: jboolean,
    analyze: jboolean,
) -> JString<'local> {
    ok_or_throw_with_return!(
        env,
        inner_into_explain_plan(
            &mut env,
            java_dataset,
            sql,
            table_name,
            with_row_id,
            with_row_addr,
            verbose,
            analyze
        )
        .map_err(|e| Error::io_error(e.to_string())),
        JString::default()
    )
}

/// Builds the SQL query described by the JNI arguments and returns its
/// rendered execution plan as a new Java string.
///
/// `verbose` and `analyze` are JNI booleans forwarded to
/// `into_explain_plan` after comparison against `JNI_TRUE`.
///
/// Errors from building the query, rendering the plan, or allocating the
/// Java string are propagated to the caller (which converts them into a
/// thrown Java exception).
#[allow(clippy::too_many_arguments)]
fn inner_into_explain_plan<'local>(
    env: &mut JNIEnv<'local>,
    java_dataset: JObject,
    sql: JString,
    table_name: JObject,
    with_row_id: jboolean,
    with_row_addr: jboolean,
    verbose: jboolean,
    analyze: jboolean,
) -> Result<JString<'local>> {
    let builder = sql_builder(
        env,
        java_dataset,
        sql,
        table_name,
        with_row_id,
        with_row_addr,
    )?;

    let explain = RT.block_on(async move {
        // Propagate build failures with `?` instead of `.unwrap()`: a panic
        // here would unwind across the JNI boundary, which is undefined
        // behavior. The error converts into the same Result error type the
        // surrounding block already returns.
        builder
            .build()
            .await?
            .into_explain_plan(verbose == JNI_TRUE, analyze == JNI_TRUE)
            .await
    })?;

    Ok(env.new_string(explain)?)
}

fn sql_builder(
env: &mut JNIEnv,
java_dataset: JObject,
Expand Down
15 changes: 0 additions & 15 deletions java/core/src/main/java/com/lancedb/lance/SqlQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,4 @@ private static native void intoBatchRecords(
boolean withRowAddr,
long streamAddress)
throws IOException;

/**
 * Renders this query's execution plan as a string.
 *
 * @param verbose whether to produce a verbose plan
 * @param analyze whether to execute the query and report runtime metrics
 * @return the rendered plan text
 * @throws IOException if building or explaining the query fails
 */
public String intoExplainPlan(boolean verbose, boolean analyze) throws IOException {
  Optional<String> tableName = Optional.ofNullable(table);
  return intoExplainPlan(dataset, sql, tableName, withRowId, withRowAddr, verbose, analyze);
}

/**
 * Native implementation backing {@link #intoExplainPlan(boolean, boolean)};
 * forwards the query description to the Rust JNI layer and returns the
 * rendered plan text.
 */
private static native String intoExplainPlan(
    Dataset dataset,
    String sql,
    Optional<String> tableName,
    boolean withRowId,
    boolean withRowAddr,
    boolean verbose,
    boolean analyze)
    throws IOException;
}
8 changes: 0 additions & 8 deletions java/core/src/test/java/com/lancedb/lance/SqlQueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,4 @@ public void testToRecordBatches() throws IOException {
reader.getVectorSchemaRoot().getSchema().toString());
reader.close();
}

@Test
public void testToExplainPlan() throws IOException {
  // Explain (verbose, without analyze) a simple aggregation and verify the
  // rendered plan mentions the aggregation step in some form.
  String explainPlan =
      dataset.sql("select sum(id) from " + NAME).tableName(NAME).intoExplainPlan(true, false);
  boolean mentionsAggregation = explainPlan.contains("Aggregate") || explainPlan.contains("SUM");
  Assertions.assertTrue(mentionsAggregation);
}
}
2 changes: 0 additions & 2 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3762,8 +3762,6 @@ def test_dataset_sql(tmp_path: Path):
ds = lance.write_dataset(table, tmp_path / "test")

query = ds.sql("SELECT * FROM test WHERE id > 1").table_name("test").build()
explain_plan = query.explain_plan(verbose=True)
assert "Filter" in explain_plan

result = query.to_batch_records()
expected = pa.table({"id": [2, 3], "value": ["b", "c"]})
Expand Down
13 changes: 0 additions & 13 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2082,19 +2082,6 @@ impl SqlQuery {
Box::new(LanceReader::from_stream(dataset_stream));
Python::with_gil(|py| reader.into_pyarrow(py))
}

/// Renders the query's execution plan as a string.
///
/// `verbose` controls plan verbosity; `analyze` additionally executes the
/// query and reports runtime metrics. Any failure — from the runtime, from
/// building the query, or from rendering the plan — surfaces to Python as a
/// `ValueError`.
#[pyo3(signature = (verbose=false, analyze=false))]
fn explain_plan(&self, verbose: bool, analyze: bool) -> PyResult<String> {
    let builder = self.builder.clone();
    // block_on yields a nested Result: the outer layer is the runtime's,
    // the inner layer is the query's. Unwrap both explicitly.
    let outcome = RT.block_on(None, async move {
        builder
            .build()
            .await?
            .into_explain_plan(verbose, analyze)
            .await
    });
    match outcome {
        Ok(Ok(plan)) => Ok(plan),
        Ok(Err(e)) => Err(PyValueError::new_err(e.to_string())),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
}

#[pyclass(name = "SqlQueryBuilder", module = "_lib", subclass)]
Expand Down
89 changes: 57 additions & 32 deletions rust/lance/src/dataset/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::datafusion::LanceTableProvider;
use crate::Dataset;
use arrow_array::{Array, RecordBatch, StringArray};
use arrow_array::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -107,33 +107,11 @@ impl SqlQuery {
/// Consumes the query and returns the underlying DataFusion `DataFrame`.
pub fn into_dataframe(self) -> DataFrame {
    self.dataframe
}

/// Renders this query's execution plan as newline-joined text.
///
/// `verbose` controls plan verbosity; `analyze` additionally executes the
/// query and collects runtime metrics. The plan is produced by
/// `DataFrame::explain`, whose result batches carry the plan text in their
/// first (string) column.
pub async fn into_explain_plan(
    self,
    verbose: bool,
    analyze: bool,
) -> lance_core::Result<String> {
    let explained = self.dataframe.explain(verbose, analyze)?;
    let batches = explained.collect().await?;

    let mut lines: Vec<String> = Vec::new();
    for batch in batches.iter() {
        // DataFrame::explain guarantees a string "plan" column at index 0.
        let plan_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("Expected StringArray in 'plan' column for DataFrame.explain");
        lines.extend((0..plan_col.len()).map(|i| plan_col.value(i).to_string()));
    }

    Ok(lines.join("\n"))
}
}

#[cfg(test)]
mod tests {
use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount};
use crate::utils::test::{assert_string_matches, DatagenExt, FragmentCount, FragmentRowCount};
use all_asserts::assert_true;
use arrow_array::cast::AsArray;
use arrow_array::types::{Int32Type, Int64Type, UInt64Type};
Expand Down Expand Up @@ -233,27 +211,74 @@ mod tests {
}

#[tokio::test]
async fn test_sql_explain_plan() {
async fn test_explain() {
let mut ds = gen_batch()
.col("x", array::step::<Int32Type>())
.col("y", array::step_custom::<Int32Type>(0, 2))
.into_dataset(
"memory://test_sql_explain_plan",
FragmentCount::from(2),
FragmentRowCount::from(5),
"memory://test_sql_dataset",
FragmentCount::from(10),
FragmentRowCount::from(10),
)
.await
.unwrap();

let builder = ds
.sql("SELECT SUM(x) FROM foo WHERE y > 2")
let results = ds
.sql("EXPLAIN SELECT * FROM foo where y >= 100")
.table_name("foo")
.build()
.await
.unwrap()
.into_batch_records()
.await
.unwrap();
let results = results.into_iter().next().unwrap();

let plan = format!("{:?}", results);
let expected_pattern = r#"...columns: [StringArray
[
"logical_plan",
"physical_plan",
], StringArray
[
"TableScan: foo projection=[x, y], full_filters=[foo.y >= Int32(100)]",
"ProjectionExec: expr=[x@0 as x, y@1 as y]\n LanceRead: uri=test_sql_dataset/data, projection=[x, y], num_fragments=10, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=y >= Int32(100), refine_filter=y >= Int32(100)\n",
]], row_count: 2 }"#;
assert_string_matches(&plan, expected_pattern).unwrap();
}

let plan = builder.into_explain_plan(true, false).await.unwrap();
#[tokio::test]
async fn test_analyze() {
let mut ds = gen_batch()
.col("x", array::step::<Int32Type>())
.col("y", array::step_custom::<Int32Type>(0, 2))
.into_dataset(
"memory://test_sql_dataset",
FragmentCount::from(10),
FragmentRowCount::from(10),
)
.await
.unwrap();

let results = ds
.sql("EXPLAIN ANALYZE SELECT * FROM foo where y >= 100")
.table_name("foo")
.build()
.await
.unwrap()
.into_batch_records()
.await
.unwrap();
let results = results.into_iter().next().unwrap();

assert!(plan.contains("Aggregate") || plan.contains("SUM"));
let plan = format!("{:?}", results);
let expected_pattern = r#"...columns: [StringArray
[
"Plan with Metrics",
], StringArray
[
"ProjectionExec: expr=[x@0 as x, y@1 as y], metrics=[output_rows=50, elapsed_compute=...]\n LanceRead: uri=test_sql_dataset/data, projection=[x, y], num_fragments=..., range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=y >= Int32(100), refine_filter=y >= Int32(100), metrics=[output_rows=..., elapsed_compute=..., bytes_read=..., fragments_scanned=..., iops=..., ranges_scanned=..., requests=..., rows_scanned=..., task_wait_time=...]\n",
]], row_count: 1 }"#;
assert_string_matches(&plan, expected_pattern).unwrap();
}
}
Loading