diff --git a/java/core/lance-jni/src/sql.rs b/java/core/lance-jni/src/sql.rs index b6882d8e9d0..e77addf37d9 100644 --- a/java/core/lance-jni/src/sql.rs +++ b/java/core/lance-jni/src/sql.rs @@ -77,67 +77,6 @@ fn inner_into_batch_records( Ok(()) } -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_SqlQuery_intoExplainPlan<'local>( - mut env: JNIEnv<'local>, - _class: JClass, - java_dataset: JObject, - sql: JString, - table_name: JObject, - with_row_id: jboolean, - with_row_addr: jboolean, - verbose: jboolean, - analyze: jboolean, -) -> JString<'local> { - ok_or_throw_with_return!( - env, - inner_into_explain_plan( - &mut env, - java_dataset, - sql, - table_name, - with_row_id, - with_row_addr, - verbose, - analyze - ) - .map_err(|e| Error::io_error(e.to_string())), - JString::default() - ) -} - -#[allow(clippy::too_many_arguments)] -fn inner_into_explain_plan<'local>( - env: &mut JNIEnv<'local>, - java_dataset: JObject, - sql: JString, - table_name: JObject, - with_row_id: jboolean, - with_row_addr: jboolean, - verbose: jboolean, - analyze: jboolean, -) -> Result> { - let builder = sql_builder( - env, - java_dataset, - sql, - table_name, - with_row_id, - with_row_addr, - )?; - - let explain = RT.block_on(async move { - builder - .build() - .await - .unwrap() - .into_explain_plan(verbose == JNI_TRUE, analyze == JNI_TRUE) - .await - })?; - - Ok(env.new_string(explain)?) -} - fn sql_builder( env: &mut JNIEnv, java_dataset: JObject, diff --git a/java/core/src/main/java/com/lancedb/lance/SqlQuery.java b/java/core/src/main/java/com/lancedb/lance/SqlQuery.java index 9321f322b18..e35ed0784c1 100644 --- a/java/core/src/main/java/com/lancedb/lance/SqlQuery.java +++ b/java/core/src/main/java/com/lancedb/lance/SqlQuery.java @@ -65,19 +65,4 @@ private static native void intoBatchRecords( boolean withRowAddr, long streamAddress) throws IOException; - - public String intoExplainPlan(boolean verbose, boolean analyze) throws IOException { - return intoExplainPlan( - dataset, sql, Optional.ofNullable(table), withRowId, withRowAddr, verbose, analyze); - } - - private static native String intoExplainPlan( - Dataset dataset, - String sql, - Optional tableName, - boolean withRowId, - boolean withRowAddr, - boolean verbose, - boolean analyze) - throws IOException; } diff --git a/java/core/src/test/java/com/lancedb/lance/SqlQueryTest.java b/java/core/src/test/java/com/lancedb/lance/SqlQueryTest.java index 09d9f56fb35..28ff0728a5e 100644 --- a/java/core/src/test/java/com/lancedb/lance/SqlQueryTest.java +++ b/java/core/src/test/java/com/lancedb/lance/SqlQueryTest.java @@ -108,12 +108,4 @@ public void testToRecordBatches() throws IOException { reader.getVectorSchemaRoot().getSchema().toString()); reader.close(); } - - @Test - public void testToExplainPlan() throws IOException { - String plan = - dataset.sql("select sum(id) from " + NAME).tableName(NAME).intoExplainPlan(true, false); - - Assertions.assertTrue(plan.contains("Aggregate") || plan.contains("SUM")); - } } diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index ce1ada792b0..2ce9003bd68 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -3762,8 +3762,6 @@ def test_dataset_sql(tmp_path: Path): ds = lance.write_dataset(table, tmp_path / "test") query = ds.sql("SELECT * FROM test WHERE id > 1").table_name("test").build() - explain_plan = query.explain_plan(verbose=True) - assert "Filter" in explain_plan result = query.to_batch_records() expected = pa.table({"id": [2, 3], "value": ["b", "c"]}) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 65647ec81ec..7e0f646a0fe 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -2082,19 +2082,6 @@ impl SqlQuery { Box::new(LanceReader::from_stream(dataset_stream)); Python::with_gil(|py| reader.into_pyarrow(py)) } - - #[pyo3(signature = (verbose=false, analyze=false))] - fn explain_plan(&self, verbose: bool, analyze: bool) -> PyResult { - let builder = self.builder.clone(); - let plan = RT - .block_on(None, async move { - let query = builder.build().await?; - query.into_explain_plan(verbose, analyze).await - }) - .map_err(|e| PyValueError::new_err(e.to_string()))? - .map_err(|e| PyValueError::new_err(e.to_string()))?; - Ok(plan) - } } #[pyclass(name = "SqlQueryBuilder", module = "_lib", subclass)] diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index b6612b20e5e..4846cabc40b 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -3,7 +3,7 @@ use crate::datafusion::LanceTableProvider; use crate::Dataset; -use arrow_array::{Array, RecordBatch, StringArray}; +use arrow_array::RecordBatch; use datafusion::dataframe::DataFrame; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::SessionContext; @@ -107,33 +107,11 @@ impl SqlQuery { pub fn into_dataframe(self) -> DataFrame { self.dataframe } - - pub async fn into_explain_plan( - self, - verbose: bool, - analyze: bool, - ) -> lance_core::Result { - let explained_df = self.dataframe.explain(verbose, analyze)?; - let batches = explained_df.collect().await?; - let mut lines = Vec::new(); - for batch in &batches { - let column = batch.column(0); - let array = column - .as_any() - .downcast_ref::() - .expect("Expected StringArray in 'plan' column for DataFrame.explain"); - for i in 0..array.len() { - lines.push(array.value(i).to_string()); - } - } - - Ok(lines.join("\n")) - } } #[cfg(test)] mod tests { - use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; + use crate::utils::test::{assert_string_matches, DatagenExt, FragmentCount, FragmentRowCount}; use all_asserts::assert_true; use arrow_array::cast::AsArray; use arrow_array::types::{Int32Type, Int64Type, UInt64Type}; @@ -233,27 +211,74 @@ mod tests { } #[tokio::test] - async fn test_sql_explain_plan() { + async fn test_explain() { let mut ds = gen_batch() .col("x", array::step::()) .col("y", array::step_custom::(0, 2)) .into_dataset( - "memory://test_sql_explain_plan", - FragmentCount::from(2), - FragmentRowCount::from(5), + "memory://test_sql_dataset", + FragmentCount::from(10), + FragmentRowCount::from(10), ) .await .unwrap(); - let builder = ds - .sql("SELECT SUM(x) FROM foo WHERE y > 2") + let results = ds + .sql("EXPLAIN SELECT * FROM foo where y >= 100") .table_name("foo") .build() .await + .unwrap() + .into_batch_records() + .await .unwrap(); + let results = results.into_iter().next().unwrap(); + + let plan = format!("{:?}", results); + let expected_pattern = r#"...columns: [StringArray +[ + "logical_plan", + "physical_plan", +], StringArray +[ + "TableScan: foo projection=[x, y], full_filters=[foo.y >= Int32(100)]", + "ProjectionExec: expr=[x@0 as x, y@1 as y]\n LanceRead: uri=test_sql_dataset/data, projection=[x, y], num_fragments=10, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=y >= Int32(100), refine_filter=y >= Int32(100)\n", +]], row_count: 2 }"#; + assert_string_matches(&plan, expected_pattern).unwrap(); + } - let plan = builder.into_explain_plan(true, false).await.unwrap(); + #[tokio::test] + async fn test_analyze() { + let mut ds = gen_batch() + .col("x", array::step::()) + .col("y", array::step_custom::(0, 2)) + .into_dataset( + "memory://test_sql_dataset", + FragmentCount::from(10), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + + let results = ds + .sql("EXPLAIN ANALYZE SELECT * FROM foo where y >= 100") + .table_name("foo") + .build() + .await + .unwrap() + .into_batch_records() + .await + .unwrap(); + let results = results.into_iter().next().unwrap(); - assert!(plan.contains("Aggregate") || plan.contains("SUM")); + let plan = format!("{:?}", results); + let expected_pattern = r#"...columns: [StringArray +[ + "Plan with Metrics", +], StringArray +[ + "ProjectionExec: expr=[x@0 as x, y@1 as y], metrics=[output_rows=50, elapsed_compute=...]\n LanceRead: uri=test_sql_dataset/data, projection=[x, y], num_fragments=..., range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=y >= Int32(100), refine_filter=y >= Int32(100), metrics=[output_rows=..., elapsed_compute=..., bytes_read=..., fragments_scanned=..., iops=..., ranges_scanned=..., requests=..., rows_scanned=..., task_wait_time=...]\n", +]], row_count: 1 }"#; + assert_string_matches(&plan, expected_pattern).unwrap(); } }