Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions java/core/lance-jni/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,67 +77,6 @@ fn inner_into_batch_records(
Ok(())
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_SqlQuery_intoExplainPlan<'local>(
mut env: JNIEnv<'local>,
_class: JClass,
java_dataset: JObject,
sql: JString,
table_name: JObject,
with_row_id: jboolean,
with_row_addr: jboolean,
verbose: jboolean,
analyze: jboolean,
) -> JString<'local> {
ok_or_throw_with_return!(
env,
inner_into_explain_plan(
&mut env,
java_dataset,
sql,
table_name,
with_row_id,
with_row_addr,
verbose,
analyze
)
.map_err(|e| Error::io_error(e.to_string())),
JString::default()
)
}

/// Builds the SQL query described by the arguments and returns its explain plan as a
/// new Java string.
///
/// `with_row_id` and `with_row_addr` are forwarded to `sql_builder`; `verbose` and
/// `analyze` are compared against `JNI_TRUE` and forwarded to `into_explain_plan`.
///
/// # Errors
///
/// Returns an error if the builder cannot be created, if the query fails to build,
/// if producing the explain plan fails, or if the Java string cannot be allocated.
// The signature mirrors the JNI entry point one-to-one, hence the long argument list.
#[allow(clippy::too_many_arguments)]
fn inner_into_explain_plan<'local>(
    env: &mut JNIEnv<'local>,
    java_dataset: JObject,
    sql: JString,
    table_name: JObject,
    with_row_id: jboolean,
    with_row_addr: jboolean,
    verbose: jboolean,
    analyze: jboolean,
) -> Result<JString<'local>> {
    let builder = sql_builder(
        env,
        java_dataset,
        sql,
        table_name,
        with_row_id,
        with_row_addr,
    )?;

    let explain = RT.block_on(async move {
        // Propagate build failures (e.g. invalid SQL) to the caller instead of
        // panicking inside a function that is invoked from the JVM.
        let query = builder.build().await?;
        query
            .into_explain_plan(verbose == JNI_TRUE, analyze == JNI_TRUE)
            .await
    })?;

    Ok(env.new_string(explain)?)
}

fn sql_builder(
env: &mut JNIEnv,
java_dataset: JObject,
Expand Down
15 changes: 0 additions & 15 deletions java/core/src/main/java/com/lancedb/lance/SqlQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,4 @@ private static native void intoBatchRecords(
boolean withRowAddr,
long streamAddress)
throws IOException;

/**
 * Returns the explain plan of this SQL query as text.
 *
 * @param verbose forwarded unchanged to the native explain call
 * @param analyze forwarded unchanged to the native explain call
 * @return the explain plan produced by the native layer
 * @throws IOException if the native call reports a failure
 */
public String intoExplainPlan(boolean verbose, boolean analyze) throws IOException {
  // The table name is optional on the Java side; wrap it before crossing into native code.
  Optional<String> tableName = Optional.ofNullable(table);
  return intoExplainPlan(dataset, sql, tableName, withRowId, withRowAddr, verbose, analyze);
}

/**
 * Native counterpart of {@code intoExplainPlan(boolean, boolean)}.
 *
 * @param dataset the dataset the query runs against
 * @param sql the SQL text of the query
 * @param tableName optional table name for the query; may be empty
 * @param withRowId forwarded to the native SQL builder
 * @param withRowAddr forwarded to the native SQL builder
 * @param verbose forwarded to the native explain call
 * @param analyze forwarded to the native explain call
 * @return the explain plan as a string
 * @throws IOException if the native layer reports a failure
 */
private static native String intoExplainPlan(
Dataset dataset,
String sql,
Optional<String> tableName,
boolean withRowId,
boolean withRowAddr,
boolean verbose,
boolean analyze)
throws IOException;
}
23 changes: 21 additions & 2 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3757,13 +3757,32 @@ def test_metadata_cache_size(tmp_path):
assert zero_cache_size < default_size


def test_dataset_sql_explain_analyze(tmp_path: Path):
    """EXPLAIN and EXPLAIN ANALYZE statements issued through ``ds.sql`` yield plan text."""
    table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
    ds = lance.write_dataset(table, tmp_path / "test")

    # Plain EXPLAIN: the rendered plan must mention the filter in one of the
    # accepted spellings.
    query = ds.sql("EXPLAIN SELECT * FROM test WHERE id > 1").table_name("test").build()
    batch_records = query.to_batch_records()
    explain_plan = pa.Table.from_batches(batch_records).to_pandas().to_string()
    assert any(k in explain_plan for k in ("Filter", "full_filter", "refine_filter")), (
        explain_plan
    )

    # EXPLAIN ANALYZE: the rendered output must mention metrics.
    query = (
        ds.sql("EXPLAIN ANALYZE SELECT * FROM test WHERE id > 1")
        .table_name("test")
        .build()
    )
    batch_records = query.to_batch_records()
    analyze_plan = pa.Table.from_batches(batch_records).to_pandas().to_string()
    # Show the rendered plan on failure, matching the first assertion.
    assert "Metrics" in analyze_plan, analyze_plan


def test_dataset_sql(tmp_path: Path):
table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
ds = lance.write_dataset(table, tmp_path / "test")

query = ds.sql("SELECT * FROM test WHERE id > 1").table_name("test").build()
explain_plan = query.explain_plan(verbose=True)
assert "Filter" in explain_plan

result = query.to_batch_records()
expected = pa.table({"id": [2, 3], "value": ["b", "c"]})
Expand Down
13 changes: 0 additions & 13 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2082,19 +2082,6 @@ impl SqlQuery {
Box::new(LanceReader::from_stream(dataset_stream));
Python::with_gil(|py| reader.into_pyarrow(py))
}

/// Returns the explain plan of this query as a string.
///
/// Both `verbose` and `analyze` default to `false` on the Python side and are passed
/// through to `into_explain_plan`. Any failure, whether from running the future or
/// from the query itself, surfaces as a Python `ValueError`.
#[pyo3(signature = (verbose=false, analyze=false))]
fn explain_plan(&self, verbose: bool, analyze: bool) -> PyResult<String> {
    // The builder is cloned so the future can own it while `self` stays borrowed.
    let query_builder = self.builder.clone();
    let explain_future = async move {
        let query = query_builder.build().await?;
        query.into_explain_plan(verbose, analyze).await
    };

    // Outer layer: result of driving the future; inner layer: result of the query.
    let runtime_outcome = RT.block_on(None, explain_future);
    let query_outcome =
        runtime_outcome.map_err(|err| PyValueError::new_err(err.to_string()))?;
    query_outcome.map_err(|err| PyValueError::new_err(err.to_string()))
}
}

#[pyclass(name = "SqlQueryBuilder", module = "_lib", subclass)]
Expand Down
49 changes: 1 addition & 48 deletions rust/lance/src/dataset/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::datafusion::LanceTableProvider;
use crate::Dataset;
use arrow_array::{Array, RecordBatch, StringArray};
use arrow_array::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -107,28 +107,6 @@ impl SqlQuery {
/// Consumes this query and returns its underlying [`DataFrame`].
pub fn into_dataframe(self) -> DataFrame {
self.dataframe
}

/// Consumes this query and returns its explain plan as newline-joined text.
///
/// `verbose` and `analyze` are passed straight to `DataFrame::explain`. The explain
/// result is collected and every row of its `plan` column becomes one entry in the
/// returned string, with entries separated by `\n`.
///
/// # Errors
///
/// Returns an error if building the explain DataFrame or collecting it fails.
///
/// # Panics
///
/// Panics if the explain output has no `plan` column or that column is not a
/// `StringArray`; either would mean the explain output schema changed.
pub async fn into_explain_plan(
    self,
    verbose: bool,
    analyze: bool,
) -> lance_core::Result<String> {
    let explained_df = self.dataframe.explain(verbose, analyze)?;
    let batches = explained_df.collect().await?;
    let mut lines = Vec::new();
    for batch in &batches {
        // The explain output carries a `plan_type` label column followed by the
        // `plan` text column. Select the plan text by name: positional index 0 is
        // the label column, not the plan.
        let array = batch
            .column_by_name("plan")
            .expect("Expected 'plan' column in DataFrame.explain output")
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("Expected StringArray in 'plan' column for DataFrame.explain");
        lines.extend((0..array.len()).map(|i| array.value(i).to_string()));
    }

    Ok(lines.join("\n"))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -231,29 +209,4 @@ mod tests {
pretty_assertions::assert_eq!(results.num_rows(), 1);
pretty_assertions::assert_eq!(results.column(0).as_primitive::<Int64Type>().value(0), 50);
}

/// The verbose explain output of an aggregate query must mention the aggregation.
#[tokio::test]
async fn test_sql_explain_plan() {
    let fragment_count = FragmentCount::from(2);
    let rows_per_fragment = FragmentRowCount::from(5);
    let mut dataset = gen_batch()
        .col("x", array::step::<Int32Type>())
        .col("y", array::step_custom::<Int32Type>(0, 2))
        .into_dataset(
            "memory://test_sql_explain_plan",
            fragment_count,
            rows_per_fragment,
        )
        .await
        .unwrap();

    let query = dataset
        .sql("SELECT SUM(x) FROM foo WHERE y > 2")
        .table_name("foo")
        .build()
        .await
        .unwrap();

    // verbose = true, analyze = false
    let explain_text = query.into_explain_plan(true, false).await.unwrap();

    let mentions_aggregation = ["Aggregate", "SUM"]
        .iter()
        .any(|needle| explain_text.contains(*needle));
    assert!(mentions_aggregation);
}
}
Loading