diff --git a/native/core/src/execution/datafusion/mod.rs b/native/core/src/execution/datafusion/mod.rs index 6f81ee9181..ca41fa0aa0 100644 --- a/native/core/src/execution/datafusion/mod.rs +++ b/native/core/src/execution/datafusion/mod.rs @@ -21,4 +21,5 @@ pub mod expressions; mod operators; pub mod planner; pub mod shuffle_writer; +pub(crate) mod spark_plan; mod util; diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 73541b0a4e..aa6b3e91cf 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -85,6 +85,7 @@ use datafusion::{ use datafusion_functions_nested::concat::ArrayAppend; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use crate::execution::datafusion::spark_plan::SparkPlan; use datafusion_comet_proto::{ spark_expression::{ self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, @@ -127,8 +128,8 @@ type PhyExprResult = Result, String)>, ExecutionError type PartitionPhyExprResult = Result>, ExecutionError>; struct JoinParameters { - pub left: Arc, - pub right: Arc, + pub left: Arc, + pub right: Arc, pub join_on: Vec<(Arc, Arc)>, pub join_filter: Option, pub join_type: DFJoinType, @@ -847,7 +848,13 @@ impl PhysicalPlanner { } } - /// Create a DataFusion physical plan from Spark physical plan. + /// Create a DataFusion physical plan from Spark physical plan. There is a level of + /// abstraction where a tree of SparkPlan nodes is returned. There is a 1:1 mapping from a + /// protobuf Operator (that represents a Spark operator) to a native SparkPlan struct. We + /// need this 1:1 mapping so that we can report metrics back to Spark. The native execution + /// plan that is generated for each Operator is sometimes a single ExecutionPlan, but in some + /// cases we generate a tree of ExecutionPlans and we need to collect metrics for all of these + /// plans so we store references to them in the SparkPlan struct. /// /// `inputs` is a vector of input source IDs. It is used to create `ScanExec`s. Each `ScanExec` /// will be assigned a unique ID from `inputs` and the ID will be used to identify the input @@ -861,11 +868,11 @@ impl PhysicalPlanner { /// /// Note that we return created `Scan`s which will be kept at JNI API. JNI calls will use it to /// feed in new input batch from Spark JVM side. - pub fn create_plan<'a>( + pub(crate) fn create_plan<'a>( &'a self, spark_plan: &'a Operator, inputs: &mut Vec>, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> Result<(Vec, Arc), ExecutionError> { let children = &spark_plan.children; match spark_plan.op_struct.as_ref().unwrap() { OpStruct::Projection(project) => { @@ -880,7 +887,14 @@ impl PhysicalPlanner { .map(|r| (r, format!("col_{}", idx))) }) .collect(); - Ok((scans, Arc::new(ProjectionExec::try_new(exprs?, child)?))) + let projection = Arc::new(ProjectionExec::try_new( + exprs?, + Arc::clone(&child.native_plan), + )?); + Ok(( + scans, + Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])), + )) } OpStruct::Filter(filter) => { assert!(children.len() == 1); @@ -888,7 +902,14 @@ impl PhysicalPlanner { let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; - Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?))) + let filter = Arc::new(FilterExec::try_new( + predicate, + Arc::clone(&child.native_plan), + )?); + Ok(( + scans, + Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])), + )) } OpStruct::HashAgg(agg) => { assert!(children.len() == 1); @@ -920,13 +941,13 @@ impl PhysicalPlanner { let num_agg = agg.agg_exprs.len(); let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect(); - let aggregate = Arc::new( + let aggregate: Arc = Arc::new( datafusion::physical_plan::aggregates::AggregateExec::try_new( mode, group_by, aggr_expr, vec![None; num_agg], // no filter expressions - Arc::clone(&child), + Arc::clone(&child.native_plan), Arc::clone(&schema), )?, ); @@ -940,8 +961,11 @@ impl PhysicalPlanner { }) .collect(); - let exec: Arc = if agg.result_exprs.is_empty() { - aggregate + if agg.result_exprs.is_empty() { + Ok(( + scans, + Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), + )) } else { // For final aggregation, DF's hash aggregate exec doesn't support Spark's // aggregate result expressions like `COUNT(col) + 1`, but instead relying @@ -950,17 +974,34 @@ impl PhysicalPlanner { // // Note that `result_exprs` should only be set for final aggregation on the // Spark side. - Arc::new(ProjectionExec::try_new(result_exprs?, aggregate)?) - }; - - Ok((scans, exec)) + let projection = Arc::new(ProjectionExec::try_new( + result_exprs?, + Arc::clone(&aggregate), + )?); + Ok(( + scans, + Arc::new(SparkPlan::new_with_additional( + spark_plan.plan_id, + projection, + vec![child], + vec![aggregate], + )), + )) + } } OpStruct::Limit(limit) => { assert!(children.len() == 1); let num = limit.limit; let (scans, child) = self.create_plan(&children[0], inputs)?; - Ok((scans, Arc::new(LocalLimitExec::new(child, num as usize)))) + let limit = Arc::new(LocalLimitExec::new( + Arc::clone(&child.native_plan), + num as usize, + )); + Ok(( + scans, + Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])), + )) } OpStruct::Sort(sort) => { assert!(children.len() == 1); @@ -978,11 +1019,20 @@ impl PhysicalPlanner { // SortExec fails in some cases if we do not unpack dictionary-encoded arrays, and // it would be more efficient if we could avoid that. // https://github.com/apache/datafusion-comet/issues/963 - let child = Self::wrap_in_copy_exec(child); + let child_copied = Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)); + + let sort = Arc::new( + SortExec::new(LexOrdering::new(exprs?), Arc::clone(&child_copied)) + .with_fetch(fetch), + ); Ok(( scans, - Arc::new(SortExec::new(LexOrdering::new(exprs?), child).with_fetch(fetch)), + Arc::new(SparkPlan::new( + spark_plan.plan_id, + sort, + vec![Arc::clone(&child)], + )), )) } OpStruct::Scan(scan) => { @@ -1008,7 +1058,10 @@ impl PhysicalPlanner { // The `ScanExec` operator will take actual arrays from Spark during execution let scan = ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?; - Ok((vec![scan.clone()], Arc::new(scan))) + Ok(( + vec![scan.clone()], + Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])), + )) } OpStruct::ShuffleWriter(writer) => { assert!(children.len() == 1); @@ -1017,14 +1070,20 @@ impl PhysicalPlanner { let partitioning = self .create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?; + let shuffle_writer = Arc::new(ShuffleWriterExec::try_new( + Arc::clone(&child.native_plan), + partitioning, + writer.output_data_file.clone(), + writer.output_index_file.clone(), + )?); + Ok(( scans, - Arc::new(ShuffleWriterExec::try_new( - child, - partitioning, - writer.output_data_file.clone(), - writer.output_index_file.clone(), - )?), + Arc::new(SparkPlan::new( + spark_plan.plan_id, + shuffle_writer, + vec![Arc::clone(&child)], + )), )) } OpStruct::Expand(expand) => { @@ -1068,15 +1127,18 @@ impl PhysicalPlanner { // the data corruption. Note that we only need to copy the input batch // if the child operator is `ScanExec`, because other operators after `ScanExec` // will create new arrays for the output batch. - let child = if can_reuse_input_batch(&child) { - Arc::new(CopyExec::new(child, CopyMode::UnpackOrDeepCopy)) + let input = if can_reuse_input_batch(&child.native_plan) { + Arc::new(CopyExec::new( + Arc::clone(&child.native_plan), + CopyMode::UnpackOrDeepCopy, + )) } else { - child + Arc::clone(&child.native_plan) }; - + let expand = Arc::new(CometExpandExec::new(projections, input, schema)); Ok(( scans, - Arc::new(CometExpandExec::new(projections, child, schema)), + Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])), )) } OpStruct::SortMergeJoin(join) => { @@ -1104,8 +1166,8 @@ impl PhysicalPlanner { .collect(); let join = Arc::new(SortMergeJoinExec::try_new( - join_params.left, - join_params.right, + Arc::clone(&join_params.left.native_plan), + Arc::clone(&join_params.right.native_plan), join_params.join_on, join_params.join_filter, join_params.join_type, @@ -1115,7 +1177,17 @@ impl PhysicalPlanner { false, )?); - Ok((scans, join)) + Ok(( + scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + join, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + )), + )) } OpStruct::HashJoin(join) => { let (join_params, scans) = self.parse_join_parameters( @@ -1131,8 +1203,8 @@ impl PhysicalPlanner { // to copy the input batch to avoid the data corruption from reusing the input // batch. We also need to unpack dictionary arrays, because the join operators // do not support them. - let left = Self::wrap_in_copy_exec(join_params.left); - let right = Self::wrap_in_copy_exec(join_params.right); + let left = Self::wrap_in_copy_exec(Arc::clone(&join_params.left.native_plan)); + let right = Self::wrap_in_copy_exec(Arc::clone(&join_params.right.native_plan)); let hash_join = Arc::new(HashJoinExec::try_new( left, @@ -1148,13 +1220,36 @@ impl PhysicalPlanner { )?); // If the hash join is build right, we need to swap the left and right - let hash_join = if join.build_side == BuildSide::BuildLeft as i32 { - hash_join + if join.build_side == BuildSide::BuildLeft as i32 { + Ok(( + scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + hash_join, + vec![join_params.left, join_params.right], + )), + )) } else { - swap_hash_join(hash_join.as_ref(), PartitionMode::Partitioned)? - }; - - Ok((scans, hash_join)) + // we insert a projection around the hash join in this case + let projection = + swap_hash_join(hash_join.as_ref(), PartitionMode::Partitioned)?; + let swapped_hash_join = Arc::clone(projection.children()[0]); + let mut additional_native_plans = swapped_hash_join + .children() + .iter() + .map(|p| Arc::clone(p)) + .collect::>(); + additional_native_plans.push(Arc::clone(&swapped_hash_join)); + Ok(( + scans, + Arc::new(SparkPlan::new_with_additional( + spark_plan.plan_id, + projection, + vec![join_params.left, join_params.right], + additional_native_plans, + )), + )) + } } OpStruct::Window(wnd) => { let (scans, child) = self.create_plan(&children[0], inputs)?; @@ -1187,14 +1282,15 @@ impl PhysicalPlanner { }) .collect(); + let window_agg = Arc::new(BoundedWindowAggExec::try_new( + window_expr?, + Arc::clone(&child.native_plan), + partition_exprs.to_vec(), + InputOrderMode::Sorted, + )?); Ok(( scans, - Arc::new(BoundedWindowAggExec::try_new( - window_expr?, - child, - partition_exprs.to_vec(), - InputOrderMode::Sorted, - )?), + Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])), )) } } @@ -1331,8 +1427,8 @@ impl PhysicalPlanner { Ok(( JoinParameters { - left, - right, + left: Arc::clone(&left), + right: Arc::clone(&right), join_on, join_type, join_filter, @@ -2199,6 +2295,7 @@ mod tests { use crate::execution::operators::ExecutionError; use datafusion_comet_proto::{ spark_expression::expr::ExprStruct::*, + spark_expression::Expr, spark_expression::{self, literal}, spark_operator, spark_operator::{operator::OpStruct, Operator}, @@ -2207,6 +2304,7 @@ mod tests { #[test] fn test_unpack_dictionary_primitive() { let op_scan = Operator { + plan_id: 0, children: vec![], op_struct: Some(OpStruct::Scan(spark_operator::Scan { fields: vec![spark_expression::DataType { @@ -2232,7 +2330,7 @@ mod tests { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let mut stream = datafusion_plan.execute(0, task_ctx).unwrap(); + let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); let runtime = tokio::runtime::Runtime::new().unwrap(); let (tx, mut rx) = mpsc::channel(1); @@ -2279,6 +2377,7 @@ mod tests { #[test] fn test_unpack_dictionary_string() { let op_scan = Operator { + plan_id: 0, children: vec![], op_struct: Some(OpStruct::Scan(spark_operator::Scan { fields: vec![spark_expression::DataType { @@ -2315,7 +2414,7 @@ mod tests { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let mut stream = datafusion_plan.execute(0, task_ctx).unwrap(); + let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); let runtime = tokio::runtime::Runtime::new().unwrap(); let (tx, mut rx) = mpsc::channel(1); @@ -2364,19 +2463,7 @@ mod tests { #[tokio::test()] #[allow(clippy::field_reassign_with_default)] async fn to_datafusion_filter() { - let op_scan = spark_operator::Operator { - children: vec![], - op_struct: Some(spark_operator::operator::OpStruct::Scan( - spark_operator::Scan { - fields: vec![spark_expression::DataType { - type_id: 3, - type_info: None, - }], - source: "".to_string(), - }, - )), - }; - + let op_scan = create_scan(); let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); @@ -2388,7 +2475,10 @@ mod tests { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let stream = datafusion_plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + let stream = datafusion_plan + .native_plan + .execute(0, Arc::clone(&task_ctx)) + .unwrap(); let output = collect(stream).await.unwrap(); assert!(output.is_empty()); } @@ -2442,10 +2532,85 @@ mod tests { }; Operator { + plan_id: 0, children: vec![child_op], op_struct: Some(OpStruct::Filter(spark_operator::Filter { predicate: Some(expr), })), } } + + #[test] + fn spark_plan_metrics_filter() { + let op_scan = create_scan(); + let op = create_filter(op_scan, 0); + let planner = PhysicalPlanner::default(); + + let (mut _scans, filter_exec) = planner.create_plan(&op, &mut vec![]).unwrap(); + + assert_eq!("FilterExec", filter_exec.native_plan.name()); + assert_eq!(1, filter_exec.children.len()); + assert_eq!(1, filter_exec.additional_native_plans.len()); + assert_eq!("ScanExec", filter_exec.additional_native_plans[0].name()); + + let scan_exec = &filter_exec.children()[0]; + assert_eq!("ScanExec", scan_exec.native_plan.name()); + assert_eq!(0, scan_exec.additional_native_plans.len()); + } + + #[test] + fn spark_plan_metrics_hash_join() { + let op_scan = create_scan(); + let op_join = Operator { + plan_id: 0, + children: vec![op_scan.clone(), op_scan.clone()], + op_struct: Some(OpStruct::HashJoin(spark_operator::HashJoin { + left_join_keys: vec![create_bound_reference(0)], + right_join_keys: vec![create_bound_reference(0)], + join_type: 0, + condition: None, + build_side: 0, + })), + }; + + let planner = PhysicalPlanner::default(); + + let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![]).unwrap(); + + assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); + assert_eq!(2, hash_join_exec.children.len()); + assert_eq!("ScanExec", hash_join_exec.children[0].native_plan.name()); + assert_eq!("ScanExec", hash_join_exec.children[1].native_plan.name()); + + assert_eq!(2, hash_join_exec.additional_native_plans.len()); + assert_eq!("ScanExec", hash_join_exec.additional_native_plans[0].name()); + assert_eq!("ScanExec", hash_join_exec.additional_native_plans[1].name()); + } + + fn create_bound_reference(index: i32) -> Expr { + Expr { + expr_struct: Some(Bound(spark_expression::BoundReference { + index, + datatype: Some(create_proto_datatype()), + })), + } + } + + fn create_scan() -> Operator { + Operator { + plan_id: 0, + children: vec![], + op_struct: Some(OpStruct::Scan(spark_operator::Scan { + fields: vec![create_proto_datatype()], + source: "".to_string(), + })), + } + } + + fn create_proto_datatype() -> spark_expression::DataType { + spark_expression::DataType { + type_id: 3, + type_info: None, + } + } } diff --git a/native/core/src/execution/datafusion/spark_plan.rs b/native/core/src/execution/datafusion/spark_plan.rs new file mode 100644 index 0000000000..6660c5bc4f --- /dev/null +++ b/native/core/src/execution/datafusion/spark_plan.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::execution::operators::{CopyExec, ScanExec}; +use arrow_schema::SchemaRef; +use datafusion::physical_plan::ExecutionPlan; +use std::sync::Arc; + +/// Wrapper around a native plan that maps to a Spark plan and can optionally contain +/// references to other native plans that should contribute to the Spark SQL metrics +/// for the root plan (such as CopyExec and ScanExec nodes) +#[derive(Debug, Clone)] +pub(crate) struct SparkPlan { + /// Spark plan ID (used for informational purposes only) + pub(crate) plan_id: u32, + /// The root of the native plan that was generated for this Spark plan + pub(crate) native_plan: Arc, + /// Child Spark plans + pub(crate) children: Vec>, + /// Additional native plans that were generated for this Spark plan that we need + /// to collect metrics for (such as CopyExec and ScanExec) + pub(crate) additional_native_plans: Vec>, +} + +impl SparkPlan { + /// Create a SparkPlan that consists of a single native plan + pub(crate) fn new( + plan_id: u32, + native_plan: Arc, + children: Vec>, + ) -> Self { + let mut additional_native_plans: Vec> = vec![]; + for child in &children { + collect_additional_plans(Arc::clone(&child.native_plan), &mut additional_native_plans); + } + Self { + plan_id, + native_plan, + children, + additional_native_plans, + } + } + + /// Create a SparkPlan that consists of more than one native plan + pub(crate) fn new_with_additional( + plan_id: u32, + native_plan: Arc, + children: Vec>, + additional_native_plans: Vec>, + ) -> Self { + let mut accum: Vec> = vec![]; + for plan in &additional_native_plans { + accum.push(Arc::clone(plan)); + } + for child in &children { + collect_additional_plans(Arc::clone(&child.native_plan), &mut accum); + } + Self { + plan_id, + native_plan, + children, + additional_native_plans: accum, + } + } + + /// Get the schema of the native plan + pub(crate) fn schema(&self) -> SchemaRef { + self.native_plan.schema() + } + + /// Get the child SparkPlan instances + pub(crate) fn children(&self) -> &Vec> { + &self.children + } +} + +fn collect_additional_plans( + child: Arc, + additional_native_plans: &mut Vec>, +) { + if child.as_any().is::() { + additional_native_plans.push(Arc::clone(&child)); + // CopyExec may be wrapping a ScanExec + collect_additional_plans(Arc::clone(child.children()[0]), additional_native_plans); + } else if child.as_any().is::() { + additional_native_plans.push(Arc::clone(&child)); + } +} diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 448f383c6b..083744f0a4 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -24,7 +24,7 @@ use datafusion::{ disk_manager::DiskManagerConfig, runtime_env::{RuntimeConfig, RuntimeEnv}, }, - physical_plan::{display::DisplayableExecutionPlan, ExecutionPlan, SendableRecordBatchStream}, + physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream}, prelude::{SessionConfig, SessionContext}, }; use futures::poll; @@ -59,6 +59,7 @@ use jni::{ }; use tokio::runtime::Runtime; +use crate::execution::datafusion::spark_plan::SparkPlan; use crate::execution::operators::ScanExec; use log::info; @@ -69,7 +70,7 @@ struct ExecutionContext { /// The deserialized Spark plan pub spark_plan: Operator, /// The DataFusion root operator converted from the `spark_plan` - pub root_op: Option>, + pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, /// The global reference of input sources for the DataFusion plan @@ -360,7 +361,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( if exec_context.explain_native { let formatted_plan_str = - DisplayableExecutionPlan::new(root_op.as_ref()).indent(true); + DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true); info!("Comet native query plan:\n{formatted_plan_str:}"); } @@ -369,6 +370,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( .root_op .as_ref() .unwrap() + .native_plan .execute(0, task_ctx)?; exec_context.stream = Some(stream); } else { @@ -400,12 +402,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( if exec_context.explain_native { if let Some(plan) = &exec_context.root_op { let formatted_plan_str = - DisplayableExecutionPlan::with_metrics(plan.as_ref()).indent(true); + DisplayableExecutionPlan::with_metrics(plan.native_plan.as_ref()) + .indent(true); info!( - "Comet native query plan with metrics:\ - \n[Stage {} Partition {}] plan creation (including CometScans fetching first batches) took {:?}:\ + "Comet native query plan with metrics (Plan #{} Stage {} Partition {}):\ + \n plan creation (including CometScans fetching first batches) took {:?}:\ \n{formatted_plan_str:}", - stage_id, partition, exec_context.plan_creation_time + plan.plan_id, stage_id, partition, exec_context.plan_creation_time ); } } diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 9291f32c72..4bb1c4474c 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::execution::datafusion::spark_plan::SparkPlan; use crate::jvm_bridge::jni_new_global_ref; use crate::{ errors::CometError, jvm_bridge::{jni_call, jni_new_string}, }; -use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::metrics::MetricValue; use jni::objects::{GlobalRef, JString}; use jni::{objects::JObject, JNIEnv}; use std::collections::HashMap; @@ -28,18 +29,36 @@ use std::sync::Arc; /// Updates the metrics of a CometMetricNode. This function is called recursively to /// update the metrics of all the children nodes. The metrics are pulled from the -/// DataFusion execution plan and pushed to the Java side through JNI. +/// native execution plan and pushed to the Java side through JNI. pub fn update_comet_metric( env: &mut JNIEnv, metric_node: &JObject, - execution_plan: &Arc, + spark_plan: &Arc, metrics_jstrings: &mut HashMap>, ) -> Result<(), CometError> { + // combine all metrics from all native plans for this SparkPlan + let metrics = if spark_plan.additional_native_plans.is_empty() { + spark_plan.native_plan.metrics() + } else { + let mut metrics = spark_plan.native_plan.metrics().unwrap_or_default(); + for plan in &spark_plan.additional_native_plans { + let additional_metrics = plan.metrics().unwrap_or_default(); + for c in additional_metrics.iter() { + match c.value() { + MetricValue::OutputRows(_) => { + // we do not want to double count output rows + } + _ => metrics.push(c.to_owned()), + } + } + } + Some(metrics.aggregate_by_name()) + }; + update_metrics( env, metric_node, - &execution_plan - .metrics() + &metrics .unwrap_or_default() .iter() .map(|m| m.value()) @@ -49,7 +68,7 @@ pub fn update_comet_metric( )?; unsafe { - for (i, child_plan) in execution_plan.children().iter().enumerate() { + for (i, child_plan) in spark_plan.children().iter().enumerate() { let child_metric_node: JObject = jni_call!(env, comet_metric_node(metric_node).get_child_node(i as i32) -> JObject )?; diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 533d504c4f..74ec80cb54 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -31,6 +31,9 @@ message Operator { // The child operators of this repeated Operator children = 1; + // Spark plan ID + uint32 plan_id = 2; + oneof op_struct { Scan scan = 100; Projection projection = 101; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f7d5fc91a0..2bb467af58 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2508,7 +2508,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim */ def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = { val conf = op.conf - val result = OperatorOuterClass.Operator.newBuilder() + val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id) childOp.foreach(result.addChildren) op match {