diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 04aec9d77d58a..88208b5507ae0 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -134,6 +134,16 @@ impl PartitionedFile { self.range = Some(FileRange { start, end }); self } + + /// Update the user defined extensions for this file. You can use this field + /// to pass reader specific information. + pub fn with_extensions( + mut self, + extensions: Arc, + ) -> Self { + self.extensions = Some(extensions); + self + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs index c59459ba6172e..aebbd155158fb 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion_common::{internal_err, Result}; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::file::metadata::RowGroupMetaData; @@ -182,6 +183,11 @@ impl ParquetAccessPlan { /// is returned for *all* the rows in the row groups that are not skipped. /// Thus it includes a `Select` selection for any [`RowGroupAccess::Scan`]. /// + /// # Errors + /// + /// Returns an error if the specified row selection does not specify + /// the same number of rows as in `row_group_metadata`. + /// /// # Example: No Selections /// /// Given an access plan like this @@ -228,7 +234,7 @@ impl ParquetAccessPlan { pub fn into_overall_row_selection( self, row_group_meta_data: &[RowGroupMetaData], - ) -> Option { + ) -> Result> { assert_eq!(row_group_meta_data.len(), self.row_groups.len()); // Intuition: entire row groups are filtered out using // `row_group_indexes` which come from Skip and Scan. An overall @@ -239,7 +245,32 @@ impl ParquetAccessPlan { .iter() .any(|rg| matches!(rg, RowGroupAccess::Selection(_))) { - return None; + return Ok(None); + } + + // validate all Selections + for (idx, (rg, rg_meta)) in self + .row_groups + .iter() + .zip(row_group_meta_data.iter()) + .enumerate() + { + let RowGroupAccess::Selection(selection) = rg else { + continue; + }; + let rows_in_selection = selection + .iter() + .map(|selection| selection.row_count) + .sum::(); + + let row_group_row_count = rg_meta.num_rows(); + if rows_in_selection as i64 != row_group_row_count { + return internal_err!( + "Invalid ParquetAccessPlan Selection. Row group {idx} has {row_group_row_count} rows \ + but selection only specifies {rows_in_selection} rows. \ + Selection: {selection:?}" + ); + } } let total_selection: RowSelection = self @@ -261,7 +292,7 @@ impl ParquetAccessPlan { }) .collect(); - Some(total_selection) + Ok(Some(total_selection)) } /// Return an iterator over the row group indexes that should be scanned @@ -305,6 +336,7 @@ impl ParquetAccessPlan { #[cfg(test)] mod test { use super::*; + use datafusion_common::assert_contains; use parquet::basic::LogicalType; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::{SchemaDescPtr, SchemaDescriptor}; @@ -320,7 +352,9 @@ mod test { ]); let row_group_indexes = access_plan.row_group_indexes(); - let row_selection = access_plan.into_overall_row_selection(row_group_metadata()); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); // scan all row groups, no selection assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); @@ -337,7 +371,9 @@ mod test { ]); let row_group_indexes = access_plan.row_group_indexes(); - let row_selection = access_plan.into_overall_row_selection(row_group_metadata()); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); // skip all row groups, no selection assert_eq!(row_group_indexes, vec![] as Vec); @@ -348,14 +384,22 @@ mod test { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, RowGroupAccess::Selection( - vec![RowSelector::select(5), RowSelector::skip(7)].into(), + // select / skip all 20 rows in row group 1 + vec![ + RowSelector::select(5), + RowSelector::skip(7), + RowSelector::select(8), + ] + .into(), ), RowGroupAccess::Skip, RowGroupAccess::Skip, ]); let row_group_indexes = access_plan.row_group_indexes(); - let row_selection = access_plan.into_overall_row_selection(row_group_metadata()); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); assert_eq!(row_group_indexes, vec![0, 1]); assert_eq!( @@ -366,7 +410,8 @@ mod test { RowSelector::select(10), // selectors from the second row group RowSelector::select(5), - RowSelector::skip(7) + RowSelector::skip(7), + RowSelector::select(8) ] .into() ) @@ -379,13 +424,21 @@ mod test { RowGroupAccess::Skip, RowGroupAccess::Scan, RowGroupAccess::Selection( - vec![RowSelector::select(5), RowSelector::skip(7)].into(), + // specify all 30 rows in row group 1 + vec![ + RowSelector::select(5), + RowSelector::skip(7), + RowSelector::select(18), + ] + .into(), ), RowGroupAccess::Scan, ]); let row_group_indexes = access_plan.row_group_indexes(); - let row_selection = access_plan.into_overall_row_selection(row_group_metadata()); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); assert_eq!(row_group_indexes, vec![1, 2, 3]); assert_eq!( @@ -397,6 +450,7 @@ mod test { // selectors from the third row group RowSelector::select(5), RowSelector::skip(7), + RowSelector::select(18), // select the entire fourth row group RowSelector::select(40), ] @@ -405,6 +459,53 @@ mod test { ); } + #[test] + fn test_invalid_too_few() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + // select 12 rows, but row group 1 has 20 + RowGroupAccess::Selection( + vec![RowSelector::select(5), RowSelector::skip(7)].into(), + ), + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let err = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap_err() + .to_string(); + assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); + assert_contains!(err, "Internal error: Invalid ParquetAccessPlan Selection. Row group 1 has 20 rows but selection only specifies 12 rows"); + } + + #[test] + fn test_invalid_too_many() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + // select 22 rows, but row group 1 has only 20 + RowGroupAccess::Selection( + vec![ + RowSelector::select(10), + RowSelector::skip(2), + RowSelector::select(10), + ] + .into(), + ), + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let err = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap_err() + .to_string(); + assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); + assert_contains!(err, "Invalid ParquetAccessPlan Selection. Row group 1 has 20 rows but selection only specifies 22 rows"); + } + static ROW_GROUP_METADATA: OnceLock> = OnceLock::new(); /// [`RowGroupMetaData`] that returns 4 row groups with 10, 20, 30, 40 rows diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index a5047e487eee6..46112e5380e13 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -28,6 +28,7 @@ use crate::datasource::physical_plan::{ use crate::datasource::schema_adapter::SchemaAdapterFactory; use crate::physical_optimizer::pruning::PruningPredicate; use arrow_schema::{ArrowError, SchemaRef}; +use datafusion_common::{exec_err, Result}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{StreamExt, TryStreamExt}; @@ -60,11 +61,10 @@ pub(super) struct ParquetOpener { impl FileOpener for ParquetOpener { fn open(&self, file_meta: FileMeta) -> datafusion_common::Result { let file_range = file_meta.range.clone(); - let file_metrics = ParquetFileMetrics::new( - self.partition_index, - file_meta.location().as_ref(), - &self.metrics, - ); + let extensions = file_meta.extensions.clone(); + let file_name = file_meta.location().to_string(); + let file_metrics = + ParquetFileMetrics::new(self.partition_index, &file_name, &self.metrics); let reader: Box = self.parquet_file_reader_factory.create_reader( @@ -139,7 +139,8 @@ impl FileOpener for ParquetOpener { let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let rg_metadata = file_metadata.row_groups(); // track which row groups to actually read - let access_plan = ParquetAccessPlan::new_all(rg_metadata.len()); + let access_plan = + create_initial_plan(&file_name, extensions, rg_metadata.len())?; let mut row_groups = RowGroupAccessPlanFilter::new(access_plan); // if there is a range restricting what parts of the file to read if let Some(range) = file_range.as_ref() { @@ -186,7 +187,7 @@ impl FileOpener for ParquetOpener { let row_group_indexes = access_plan.row_group_indexes(); if let Some(row_selection) = - access_plan.into_overall_row_selection(rg_metadata) + access_plan.into_overall_row_selection(rg_metadata)? { builder = builder.with_row_selection(row_selection); } @@ -212,3 +213,34 @@ impl FileOpener for ParquetOpener { })) } } + +/// Return the initial [`ParquetAccessPlan`] +/// +/// If the user has supplied one as an extension, use that +/// otherwise return a plan that scans all row groups +/// +/// Returns an error is an invalid `ParquetAccessPlan` is provided +/// +/// Note: path is only used for error messages +fn create_initial_plan( + file_name: &str, + extensions: Option>, + row_group_count: usize, +) -> Result { + if let Some(extensions) = extensions { + if let Some(access_plan) = extensions.downcast_ref::() { + let plan_len = access_plan.len(); + if plan_len != row_group_count { + return exec_err!( + "Invalid ParquetAccessPlan for {file_name}. Specified {plan_len} row groups, but file has {row_group_count}" + ); + } + + // check row group count matches the plan + return Ok(access_plan.clone()); + } + } + + // default to scanning all row groups + Ok(ParquetAccessPlan::new_all(row_group_count)) +} diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs new file mode 100644 index 0000000000000..69fee29551b60 --- /dev/null +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -0,0 +1,342 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for passing user provided [`ParquetAccessPlan`]` to `ParquetExec`]` +use crate::parquet::utils::MetricsFinder; +use crate::parquet::{create_data_batch, Scenario}; +use arrow_schema::SchemaRef; +use datafusion::common::Result; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::{ParquetAccessPlan, RowGroupAccess}; +use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::prelude::SessionContext; +use datafusion_common::assert_contains; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::ExecutionPlan; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; +use std::sync::{Arc, OnceLock}; +use tempfile::NamedTempFile; + +#[tokio::test] +async fn none() { + // no user defined plan + Test { + access_plan: None, + expected_rows: 10, + } + .run_success() + .await; +} + +#[tokio::test] +async fn scan_all() { + let parquet_metrics = Test { + access_plan: Some(ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ])), + expected_rows: 10, + } + .run_success() + .await; + + // Verify that some bytes were read + let bytes_scanned = metric_value(&parquet_metrics, "bytes_scanned").unwrap(); + assert_ne!(bytes_scanned, 0, "metrics : {parquet_metrics:#?}",); +} + +#[tokio::test] +async fn skip_all() { + let parquet_metrics = Test { + access_plan: Some(ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, + RowGroupAccess::Skip, + ])), + expected_rows: 0, + } + .run_success() + .await; + + // Verify that skipping all row groups skips reading any data at all + let bytes_scanned = metric_value(&parquet_metrics, "bytes_scanned").unwrap(); + assert_eq!(bytes_scanned, 0, "metrics : {parquet_metrics:#?}",); +} + +#[tokio::test] +async fn skip_one_row_group() { + let plans = vec![ + ParquetAccessPlan::new(vec![RowGroupAccess::Scan, RowGroupAccess::Skip]), + ParquetAccessPlan::new(vec![RowGroupAccess::Skip, RowGroupAccess::Scan]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 5, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn selection_scan() { + let plans = vec![ + ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Selection(select_one_row()), + ]), + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_one_row()), + RowGroupAccess::Scan, + ]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 6, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn skip_scan() { + let plans = vec![ + // skip one row group, scan the toehr + ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, + RowGroupAccess::Selection(select_one_row()), + ]), + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_one_row()), + RowGroupAccess::Skip, + ]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 1, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn two_selections() { + let plans = vec![ + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_one_row()), + RowGroupAccess::Selection(select_two_rows()), + ]), + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_two_rows()), + RowGroupAccess::Selection(select_one_row()), + ]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 3, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn bad_row_groups() { + let err = Test { + access_plan: Some(ParquetAccessPlan::new(vec![ + // file has only 2 row groups, but specify 3 + RowGroupAccess::Scan, + RowGroupAccess::Skip, + RowGroupAccess::Scan, + ])), + expected_rows: 0, + } + .run() + .await + .unwrap_err(); + let err_string = err.to_string(); + assert_contains!(&err_string, "Invalid ParquetAccessPlan"); + assert_contains!(&err_string, "Specified 3 row groups, but file has 2"); +} + +#[tokio::test] +async fn bad_selection() { + let err = Test { + access_plan: Some(ParquetAccessPlan::new(vec![ + // specify fewer rows than are actually in the row group + RowGroupAccess::Selection(RowSelection::from(vec![ + RowSelector::skip(1), + RowSelector::select(3), + ])), + RowGroupAccess::Skip, + ])), + // expects that we hit an error, this should not be run + expected_rows: 10000, + } + .run() + .await + .unwrap_err(); + let err_string = err.to_string(); + assert_contains!(&err_string, "Internal error: Invalid ParquetAccessPlan Selection. Row group 0 has 5 rows but selection only specifies 4 rows"); +} + +/// Return a RowSelection of 1 rows from a row group of 5 rows +fn select_one_row() -> RowSelection { + RowSelection::from(vec![ + RowSelector::skip(2), + RowSelector::select(1), + RowSelector::skip(2), + ]) +} +/// Return a RowSelection of 2 rows from a row group of 5 rows +fn select_two_rows() -> RowSelection { + RowSelection::from(vec![ + RowSelector::skip(1), + RowSelector::select(1), + RowSelector::skip(1), + RowSelector::select(1), + RowSelector::skip(1), + ]) +} + +/// Test for passing user defined ParquetAccessPlans: +/// +/// 1. Creates a parquet file with 10 rows: rg0: 5 rows, rg1: 5 rows +/// 2. Reads the parquet file with an optional user provided access plan +/// 3. Verifies that the expected number of rows is read +#[derive(Debug)] +struct Test { + access_plan: Option, + expected_rows: usize, +} + +impl Test { + /// Runs the test case, panic'ing on error. + /// + /// Returns the `MetricsSet` from the ParqeutExec + async fn run_success(self) -> MetricsSet { + self.run().await.unwrap() + } + + async fn run(self) -> Result { + let Self { + access_plan, + expected_rows, + } = self; + + let TestData { + temp_file: _, + schema, + file_name, + file_size, + } = get_test_data(); + + let mut partitioned_file = PartitionedFile::new(file_name, *file_size); + + // add the access plan, if any, as an extension + if let Some(access_plan) = access_plan { + partitioned_file = partitioned_file.with_extensions(Arc::new(access_plan)); + } + + // Create a ParquetExec to read the file + let object_store_url = ObjectStoreUrl::local_filesystem(); + let config = FileScanConfig::new(object_store_url, schema.clone()) + .with_file(partitioned_file); + let plan: Arc = ParquetExec::builder(config).build_arc(); + + // run the ParquetExec and collect the results + let ctx = SessionContext::new(); + let results = + datafusion::physical_plan::collect(Arc::clone(&plan), ctx.task_ctx()).await?; + + // calculate the total number of rows that came out + let total_rows = results.iter().map(|b| b.num_rows()).sum::(); + assert_eq!(total_rows, expected_rows); + + Ok(MetricsFinder::find_metrics(plan.as_ref()).unwrap()) + } +} + +// Holds necessary data for these tests to reuse the same parquet file +struct TestData { + // field is present as on drop the file is deleted + #[allow(dead_code)] + temp_file: NamedTempFile, + schema: SchemaRef, + file_name: String, + file_size: u64, +} + +static TEST_DATA: OnceLock = OnceLock::new(); + +/// Return a parquet file with 2 row groups each with 5 rows +fn get_test_data() -> &'static TestData { + TEST_DATA.get_or_init(|| { + let scenario = Scenario::UTF8; + let row_per_group = 5; + + let mut temp_file = tempfile::Builder::new() + .prefix("user_access_plan") + .suffix(".parquet") + .tempfile() + .expect("tempfile creation"); + + let props = WriterProperties::builder() + .set_max_row_group_size(row_per_group) + .build(); + + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); + + let mut writer = + ArrowWriter::try_new(&mut temp_file, schema.clone(), Some(props)).unwrap(); + + for batch in batches { + writer.write(&batch).expect("writing batch"); + } + writer.close().unwrap(); + + let file_name = temp_file.path().to_string_lossy().to_string(); + let file_size = temp_file.path().metadata().unwrap().len(); + + TestData { + temp_file, + schema, + file_name, + file_size, + } + }) +} + +/// Return the total value of the specified metric name +fn metric_value(parquet_metrics: &MetricsSet, metric_name: &str) -> Option { + parquet_metrics + .sum(|metric| metric.value().name() == metric_name) + .map(|v| v.as_usize()) +} diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 99769a3367228..7c8a96caa47b7 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -16,6 +16,7 @@ // under the License. //! Parquet integration tests +use crate::parquet::utils::MetricsFinder; use arrow::array::Decimal128Array; use arrow::datatypes::{ i256, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, @@ -41,8 +42,8 @@ use arrow_array::{ use arrow_schema::IntervalUnit; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ - datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, - physical_plan::{accept, metrics::MetricsSet, ExecutionPlan, ExecutionPlanVisitor}, + datasource::{provider_as_source, TableProvider}, + physical_plan::metrics::MetricsSet, prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; @@ -51,8 +52,10 @@ use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; use std::sync::Arc; use tempfile::NamedTempFile; + mod arrow_statistics; mod custom_reader; +mod external_access_plan; mod file_statistics; #[cfg(not(target_family = "windows"))] mod filter_pushdown; @@ -60,6 +63,7 @@ mod page_pruning; mod row_group_pruning; mod schema; mod schema_coercion; +mod utils; #[cfg(test)] #[ctor::ctor] @@ -303,25 +307,8 @@ impl ContextWithParquet { .expect("Running"); // find the parquet metrics - struct MetricsFinder { - metrics: Option, - } - impl ExecutionPlanVisitor for MetricsFinder { - type Error = std::convert::Infallible; - fn pre_visit( - &mut self, - plan: &dyn ExecutionPlan, - ) -> Result { - if plan.as_any().downcast_ref::().is_some() { - self.metrics = plan.metrics(); - } - // stop searching once we have found the metrics - Ok(self.metrics.is_none()) - } - } - let mut finder = MetricsFinder { metrics: None }; - accept(physical_plan.as_ref(), &mut finder).unwrap(); - let parquet_metrics = finder.metrics.unwrap(); + let parquet_metrics = + MetricsFinder::find_metrics(physical_plan.as_ref()).unwrap(); let result_rows = results.iter().map(|b| b.num_rows()).sum(); diff --git a/datafusion/core/tests/parquet/utils.rs b/datafusion/core/tests/parquet/utils.rs new file mode 100644 index 0000000000000..d8d2b2fbb8a55 --- /dev/null +++ b/datafusion/core/tests/parquet/utils.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for parquet tests + +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor}; + +/// Find the metrics from the first ParquetExec encountered in the plan +#[derive(Debug)] +pub struct MetricsFinder { + metrics: Option, +} +impl MetricsFinder { + pub fn new() -> Self { + Self { metrics: None } + } + + /// Return the metrics if found + pub fn into_metrics(self) -> Option { + self.metrics + } + + pub fn find_metrics(plan: &dyn ExecutionPlan) -> Option { + let mut finder = Self::new(); + accept(plan, &mut finder).unwrap(); + finder.into_metrics() + } +} + +impl ExecutionPlanVisitor for MetricsFinder { + type Error = std::convert::Infallible; + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + if plan.as_any().downcast_ref::().is_some() { + self.metrics = plan.metrics(); + } + // stop searching once we have found the metrics + Ok(self.metrics.is_none()) + } +} diff --git a/datafusion/core/tests/parquet_exec.rs b/datafusion/core/tests/parquet_exec.rs index 43ceb615a0623..f41f82a76c67f 100644 --- a/datafusion/core/tests/parquet_exec.rs +++ b/datafusion/core/tests/parquet_exec.rs @@ -15,5 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! End to end test for `ParquetExec` and related components + /// Run all tests that are found in the `parquet` directory mod parquet;