diff --git a/Cargo.lock b/Cargo.lock index eb6506ea..793b04c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2024,6 +2024,7 @@ dependencies = [ "rand 0.9.2", "reqwest", "structopt", + "test-case", "tokio", "tokio-stream", "tonic", @@ -5658,6 +5659,39 @@ dependencies = [ "winapi", ] +[[package]] +name = "test-case" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", + "test-case-core", +] + [[package]] name = "textwrap" version = "0.11.0" diff --git a/Cargo.toml b/Cargo.toml index 7b45f06d..59ceb984 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,3 +82,4 @@ hyper-util = "0.1.16" pretty_assertions = "1.4" reqwest = "0.12" zip = "6.0" +test-case = "3.3.1" diff --git a/src/lib.rs b/src/lib.rs index 3877b209..4653d74d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub use flight_service::{ DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, Worker, WorkerQueryContext, WorkerSessionBuilder, }; -pub use metrics::rewrite_distributed_plan_with_metrics; +pub use metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; pub use networking::{ BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, WorkerResolver, create_flight_client, diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 655add1e..36fba44a 100644 --- 
a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -2,4 +2,4 @@ pub(crate) mod proto; mod task_metrics_collector; mod task_metrics_rewriter; pub(crate) use task_metrics_collector::{MetricsCollectorResult, TaskMetricsCollector}; -pub use task_metrics_rewriter::rewrite_distributed_plan_with_metrics; +pub use task_metrics_rewriter::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index 9897d29e..6bf2fbda 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -218,6 +218,174 @@ pub fn df_metrics_set_to_proto( Ok(MetricsSetProto { metrics }) } +/// Converts a [MetricsSet] to a new [MetricsSet], renaming all metrics to have a "_{task_id}" suffix. +/// ***Custom metrics are not supported - they are omitted from the result*** +/// +/// Specific metrics like [OutputRows] will be transformed to general metrics like [Count]. +/// +/// We do it this way because, by default, [DisplayableExecutionPlan::with_metrics] aggregates +/// metrics by name (see https://github.com/apache/datafusion/blob/f0e38df39e13921ae19e79d26760c6554466955c/datafusion/physical-plan/src/display.rs#L425). +/// In some cases, we want to show metrics for each task separately, so we rename metrics to have +/// the task id in them so they do not get aggregated together. +/// +/// Notably, [DisplayableExecutionPlan::with_full_metrics] exists, but this is too verbose, as it +/// will show metrics for each partition and include arbitrary labels. Renaming allows us to +/// achieve a medium level of verbosity. 
+pub fn annotate_metrics_set_with_task_id(metrics_set: &MetricsSet, task_id: u64) -> MetricsSet { + let mut result = MetricsSet::new(); + + for metric in metrics_set.iter() { + let partition = metric.partition(); + let labels = metric.labels().to_vec(); + + let base_name = metric.value().name(); + let new_name = Cow::Owned(format!("{base_name}_{task_id}")); + + let new_metric_value = match metric.value() { + MetricValue::OutputRows(count) => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::ElapsedCompute(time) => { + let new_time = Time::new(); + new_time.add_duration(std::time::Duration::from_nanos(time.value() as u64)); + MetricValue::Time { + name: new_name, + time: new_time, + } + } + MetricValue::SpillCount(count) => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::SpilledBytes(count) => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::SpilledRows(count) => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::CurrentMemoryUsage(gauge) => { + let new_gauge = Gauge::new(); + new_gauge.set(gauge.value()); + MetricValue::Gauge { + name: new_name, + gauge: new_gauge, + } + } + MetricValue::Count { count, .. } => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::Gauge { gauge, .. } => { + let new_gauge = Gauge::new(); + new_gauge.set(gauge.value()); + MetricValue::Gauge { + name: new_name, + gauge: new_gauge, + } + } + MetricValue::Time { time, .. 
} => { + let new_time = Time::new(); + new_time.add_duration(std::time::Duration::from_nanos(time.value() as u64)); + MetricValue::Time { + name: new_name, + time: new_time, + } + } + MetricValue::StartTimestamp(timestamp) => { + let new_gauge = Gauge::new(); + if let Some(dt) = timestamp.value() { + new_gauge.set(dt.timestamp_nanos_opt().unwrap_or(0) as usize); + } + MetricValue::Gauge { + name: new_name, + gauge: new_gauge, + } + } + MetricValue::EndTimestamp(timestamp) => { + let new_gauge = Gauge::new(); + if let Some(dt) = timestamp.value() { + new_gauge.set(dt.timestamp_nanos_opt().unwrap_or(0) as usize); + } + MetricValue::Gauge { + name: new_name, + gauge: new_gauge, + } + } + MetricValue::OutputBytes(count) => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::OutputBatches(count) => { + let new_count = Count::new(); + new_count.add(count.value()); + MetricValue::Count { + name: new_name, + count: new_count, + } + } + MetricValue::PruningMetrics { + pruning_metrics, .. + } => { + // Convert to a count representing the matched value + let new_count = Count::new(); + new_count.add(pruning_metrics.matched()); + MetricValue::Count { + name: Cow::Owned(format!("{base_name}_matched_{task_id}")), + count: new_count, + } + } + MetricValue::Ratio { ratio_metrics, .. } => { + // Convert ratio to a gauge representing the percentage + let new_gauge = Gauge::new(); + if ratio_metrics.total() > 0 { + new_gauge.set((ratio_metrics.part() * 100) / ratio_metrics.total()); + } + MetricValue::Gauge { + name: new_name, + gauge: new_gauge, + } + } + // Skip custom metrics as they cannot be generically converted + MetricValue::Custom { .. 
} => continue, + }; + + result.push(Arc::new(Metric::new_with_labels( + new_metric_value, + partition, + labels, + ))); + } + + result +} + /// metrics_set_proto_to_df converts a [MetricsSetProto] to a [datafusion::physical_plan::metrics::MetricsSet]. pub fn metrics_set_proto_to_df( metrics_set_proto: &MetricsSetProto, @@ -1090,4 +1258,42 @@ mod tests { ))); test_roundtrip_helper(metrics_set, "ratio_metrics"); } + + #[test] + fn test_annotate_metrics_set_with_task_id_output_rows() { + // Create a MetricsSet with an OutputRows metric + let mut metrics_set = MetricsSet::new(); + let count = Count::new(); + count.add(1234); + let labels = vec![Label::new("operator", "scan")]; + metrics_set.push(Arc::new(Metric::new_with_labels( + MetricValue::OutputRows(count), + Some(0), + labels, + ))); + + let task_id = 42; + let annotated = annotate_metrics_set_with_task_id(&metrics_set, task_id); + + // Verify we have one metric + assert_eq!(annotated.iter().count(), 1); + + let metric = annotated.iter().next().unwrap(); + + // Verify OutputRows was converted to Count with task_id suffix + match metric.value() { + MetricValue::Count { name, count } => { + assert_eq!(name.as_ref(), "output_rows_42"); + assert_eq!(count.value(), 1234); + } + other => panic!("Expected Count, got {:?}", other.name()), + } + + // Verify labels and partition are preserved + assert_eq!(metric.partition(), Some(0)); + let labels: Vec<_> = metric.labels().iter().collect(); + assert_eq!(labels.len(), 1); + assert_eq!(labels[0].name(), "operator"); + assert_eq!(labels[0].value(), "scan"); + } } diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index ad011e5b..a35ac686 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -3,8 +3,9 @@ use crate::execution_plans::DistributedExec; use crate::execution_plans::MetricsWrapperExec; use crate::metrics::MetricsCollectorResult; use crate::metrics::TaskMetricsCollector; -use 
crate::metrics::proto::MetricsSetProto; -use crate::metrics::proto::metrics_set_proto_to_df; +use crate::metrics::proto::{ + MetricsSetProto, annotate_metrics_set_with_task_id, metrics_set_proto_to_df, +}; use crate::protobuf::StageKey; use crate::stage::Stage; use bytes::Bytes; @@ -19,10 +20,30 @@ use datafusion::physical_plan::metrics::MetricsSet; use std::sync::Arc; use std::vec; +/// Format to use when displaying metrics for a distributed plan. +#[derive(Clone, Copy)] +pub enum DistributedMetricsFormat { + /// Metrics are aggregated across all tasks. ex. a `output_rows=X` represents the output rows for all tasks. + Aggregated, + + /// Metric names are rewritten to include the task id. ex. `output_rows` -> `output_rows_0`, `output_rows_1` etc. + PerTask, +} + +impl DistributedMetricsFormat { + pub(crate) fn to_rewrite_ctx(self, task_id: u64) -> RewriteCtx { + match self { + DistributedMetricsFormat::Aggregated => RewriteCtx::default(), + DistributedMetricsFormat::PerTask => RewriteCtx::from_task_id(task_id), + } + } +} + /// Rewrites a distributed plan with metrics. Does nothing if the root node is not a [DistributedExec]. /// Returns an error if the distributed plan was not executed. pub fn rewrite_distributed_plan_with_metrics( plan: Arc, + format: DistributedMetricsFormat, ) -> Result> { let Some(distributed_exec) = plan.as_any().downcast_ref::() else { return Ok(plan); }; @@ -35,8 +56,11 @@ } = TaskMetricsCollector::new().collect(distributed_exec.prepared_plan()?)?; // Rewrite the DistributedExec's child plan with metrics. 
- let dist_exec_plan_with_metrics = - rewrite_local_plan_with_metrics(plan.children()[0].clone(), task_metrics)?; + let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics( + format.to_rewrite_ctx(0), // Task id is 0 for the DistributedExec plan + plan.children()[0].clone(), + task_metrics, + )?; let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?; let metrics_collection = Arc::new(input_task_metrics); @@ -47,7 +71,8 @@ pub fn rewrite_distributed_plan_with_metrics( let stage = network_boundary.input_stage(); // This transform is a bit inefficient because we traverse the plan nodes twice // For now, we are okay with trading off performance for simplicity. - let plan_with_metrics = stage_metrics_rewriter(stage, metrics_collection.clone())?; + let plan_with_metrics = + stage_metrics_rewriter(stage, metrics_collection.clone(), format)?; return Ok(Transformed::yes(network_boundary.with_input_stage( Stage::new( stage.query_id, @@ -63,6 +88,29 @@ pub fn rewrite_distributed_plan_with_metrics( Ok(transformed.data) } +/// Extra information for rewriting local plans. +#[derive(Default)] +pub struct RewriteCtx { + /// Used to rename metrics for the current task. + pub task_id: Option, +} + +impl RewriteCtx { + pub(crate) fn from_task_id(task_id: u64) -> RewriteCtx { + RewriteCtx { + task_id: Some(task_id), + } + } + + /// Rewrites the [MetricsSet] depending on the context. + pub(crate) fn maybe_rewrite_node_metics(&self, node_metrics: MetricsSet) -> MetricsSet { + if let Some(task_id) = self.task_id { + return annotate_metrics_set_with_task_id(&node_metrics, task_id); + } + node_metrics + } +} + /// Rewrites a local plan with metrics, stopping at network boundaries. 
/// /// Example: @@ -77,6 +125,7 @@ pub fn rewrite_distributed_plan_with_metrics( /// └── MetricsWrapperExec (wrapped: ProjectionExec) [output_rows = 2, elapsed_compute = 200] /// └── NetworkShuffleExec pub fn rewrite_local_plan_with_metrics( + ctx: RewriteCtx, plan: Arc, metrics: Vec, ) -> Result> { @@ -89,7 +138,10 @@ pub fn rewrite_local_plan_with_metrics( if idx >= metrics.len() { return internal_err!("not enough metrics provided to rewrite plan"); } - let node_metrics = metrics[idx].clone(); + let mut node_metrics = metrics[idx].clone(); + + node_metrics = ctx.maybe_rewrite_node_metics(node_metrics); + idx += 1; Ok(Transformed::yes(Arc::new(MetricsWrapperExec::new( node.clone(), @@ -131,6 +183,7 @@ pub fn rewrite_local_plan_with_metrics( pub fn stage_metrics_rewriter( stage: &Stage, metrics_collection: Arc>>, + format: DistributedMetricsFormat, ) -> Result> { let mut node_idx = 0; @@ -143,10 +196,10 @@ pub fn stage_metrics_rewriter( } // Collect metrics for this node. It should contain metrics from each task. 
- let mut stage_metrics = MetricsSetProto::new(); + let mut stage_metrics = MetricsSet::new(); - for idx in 0..stage.tasks.len() { - let stage_key = StageKey::new(Bytes::from(stage.query_id.as_bytes().to_vec()), stage.num as u64, idx as u64); + for task_id in 0..stage.tasks.len() { + let stage_key = StageKey::new(Bytes::from(stage.query_id.as_bytes().to_vec()), stage.num as u64, task_id as u64); match metrics_collection.get(&stage_key) { Some(task_metrics) => { if node_idx >= task_metrics.len() { @@ -155,15 +208,20 @@ pub fn stage_metrics_rewriter( task_metrics.len() ); } - let node_metrics = task_metrics[node_idx].clone(); - for metric in node_metrics.metrics.iter() { - stage_metrics.push(metric.clone()); + let node_metrics_protos = task_metrics[node_idx].clone(); + let mut node_metrics = metrics_set_proto_to_df(&node_metrics_protos)?; + + let rewrite_ctx = format.to_rewrite_ctx(task_id as u64); + node_metrics = rewrite_ctx.maybe_rewrite_node_metics(node_metrics); + + for metric in node_metrics.iter().map(Arc::clone) { + stage_metrics.push(metric); } } None => { return internal_err!( "not enough metrics provided to rewrite task: missing metrics for task {} in stage {}", - idx, + task_id, stage.num ); } @@ -174,7 +232,7 @@ pub fn stage_metrics_rewriter( let wrapped_plan_node: Arc = Arc::new(MetricsWrapperExec::new( plan.clone(), - metrics_set_proto_to_df(&stage_metrics)?, + stage_metrics, )); Ok(Transformed::yes(wrapped_plan_node)) }).map(|v| v.data) @@ -184,10 +242,11 @@ pub fn stage_metrics_rewriter( mod tests { use crate::PartitionIsolatorExec; use crate::metrics::proto::{ - MetricsSetProto, df_metrics_set_to_proto, metrics_set_proto_to_df, + MetricsSetProto, annotate_metrics_set_with_task_id, df_metrics_set_to_proto, + metrics_set_proto_to_df, }; - use crate::metrics::rewrite_distributed_plan_with_metrics; use crate::metrics::task_metrics_rewriter::stage_metrics_rewriter; + use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; 
use crate::protobuf::StageKey; use crate::test_utils::in_memory_channel_resolver::{ InMemoryChannelResolver, InMemoryWorkerResolver, @@ -364,7 +423,12 @@ mod tests { let metrics_collection = Arc::new(metrics_collection); // Rewrite the plan. - let rewritten_plan = stage_metrics_rewriter(&stage, metrics_collection.clone()).unwrap(); + let rewritten_plan = stage_metrics_rewriter( + &stage, + metrics_collection.clone(), + DistributedMetricsFormat::Aggregated, + ) + .unwrap(); // Collect metrics from the plan. let mut actual_metrics = vec![]; @@ -394,8 +458,13 @@ mod tests { actual_task_node_metrics_set .for_each(|metric| actual_metrics_set.push(metric.clone())); - let expected_metrics_set = - metrics_set_proto_to_df(&expected_task_node_metrics).unwrap(); // Convert to proto to check for equality. + // Convert from proto to check for equality. + let mut expected_metrics_set = + metrics_set_proto_to_df(&expected_task_node_metrics).unwrap(); + // Add task ids labels. We expect the actual metrics to be annotated by the + // rewriter + expected_metrics_set = + annotate_metrics_set_with_task_id(&expected_metrics_set, task_id as u64); assert!(metrics_set_eq(&actual_metrics_set, &expected_metrics_set)); } } @@ -440,7 +509,10 @@ mod tests { .await .unwrap(); assert!(plan.as_any().is::()); - assert!(rewrite_distributed_plan_with_metrics(plan).is_err()); + assert!( + rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated) + .is_err() + ); } // Assert every plan node has at least one metric except partition isolators, network boundary nodes, and the root DistributedExec node. 
@@ -481,7 +553,9 @@ mod tests { .unwrap(); collect(plan.clone(), ctx.task_ctx()).await.unwrap(); assert!(plan.as_any().is::()); - let rewritten_plan = rewrite_distributed_plan_with_metrics(plan).unwrap(); + let rewritten_plan = + rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated) + .unwrap(); assert_metrics_present_in_plan(&rewritten_plan); } diff --git a/src/stage.rs b/src/stage.rs index dba7e56f..577d6393 100644 --- a/src/stage.rs +++ b/src/stage.rs @@ -162,7 +162,7 @@ impl Stage { } } -use crate::rewrite_distributed_plan_with_metrics; +use crate::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; use crate::{NetworkBoundary, NetworkBoundaryExt}; use bytes::Bytes; use datafusion::common::DataFusionError; @@ -185,13 +185,16 @@ use prost::Message; use std::fmt::Write; /// explain_analyze renders an [ExecutionPlan] with metrics. -pub fn explain_analyze(executed: Arc) -> Result { +pub fn explain_analyze( + executed: Arc, + format: DistributedMetricsFormat, +) -> Result { match executed.as_any().downcast_ref::() { None => Ok(DisplayableExecutionPlan::with_metrics(executed.as_ref()) .indent(true) .to_string()), Some(_) => { - let executed = rewrite_distributed_plan_with_metrics(executed.clone())?; + let executed = rewrite_distributed_plan_with_metrics(executed.clone(), format)?; Ok(display_plan_ascii(executed.as_ref(), true)) } } @@ -208,7 +211,12 @@ pub fn display_plan_ascii(plan: &dyn ExecutionPlan, show_metrics: bool) -> Strin display_ascii(Either::Left(plan), 0, show_metrics, &mut f).unwrap(); f } else { - displayable(plan).indent(true).to_string() + match show_metrics { + true => DisplayableExecutionPlan::with_metrics(plan) + .indent(true) + .to_string(), + false => displayable(plan).indent(true).to_string(), + } } } diff --git a/tests/join.rs b/tests/join.rs index 5fe97ebd..edc92280 100644 --- a/tests/join.rs +++ b/tests/join.rs @@ -72,7 +72,7 @@ mod tests { let (state, logical_plan) = df.into_parts(); let 
physical_plan = state.create_physical_plan(&logical_plan).await?; let distributed_plan = display_plan_ascii(physical_plan.as_ref(), false); - println!("\n——————— DISTRIBUTED PLAN ———————\n\n{}", distributed_plan); + println!("\n——————— DISTRIBUTED PLAN ———————\n\n{distributed_plan}"); let distributed_results = collect(physical_plan, state.task_ctx()).await?; pretty::print_batches(&distributed_results)?; diff --git a/tests/metrics_collection.rs b/tests/metrics_collection.rs index bfbd03b4..ea878c43 100644 --- a/tests/metrics_collection.rs +++ b/tests/metrics_collection.rs @@ -7,25 +7,32 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::{ - DefaultSessionBuilder, DistributedExec, display_plan_ascii, + DefaultSessionBuilder, DistributedMetricsFormat, display_plan_ascii, rewrite_distributed_plan_with_metrics, }; use futures::TryStreamExt; - use itertools::Itertools; use std::sync::Arc; + use test_case::test_case; + #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")] + #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")] #[tokio::test] - async fn test_metrics_collection_in_aggregation() -> Result<(), Box> { + async fn test_metrics_collection_in_aggregation( + format: DistributedMetricsFormat, + ) -> Result<(), Box> { let (d_ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; let query = r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#; let s_ctx = SessionContext::default(); - let (s_physical, d_physical) = execute(&s_ctx, &d_ctx, query).await?; + let (s_physical, mut d_physical) = execute(&s_ctx, &d_ctx, query).await?; + d_physical = rewrite_distributed_plan_with_metrics(d_physical.clone(), format)?; + println!("{}", display_plan_ascii(s_physical.as_ref(), true)); + println!("{}", display_plan_ascii(d_physical.as_ref(), true)); 
assert_metrics_equal::( - ["output_rows", "bytes_scanned"], + ["output_rows", "output_bytes"], &s_physical, &d_physical, 0, @@ -34,8 +41,12 @@ mod tests { Ok(()) } + #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")] + #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")] #[tokio::test] - async fn test_metrics_collection_in_join() -> Result<(), Box> { + async fn test_metrics_collection_in_join( + format: DistributedMetricsFormat, + ) -> Result<(), Box> { let (d_ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; let query = r#" @@ -44,14 +55,14 @@ mod tests { AVG("MinTemp") as "MinTemp", "RainTomorrow" FROM weather - WHERE "RainToday" = 'yes' + WHERE "RainToday" = 'Yes' GROUP BY "RainTomorrow" ), b AS ( SELECT AVG("MaxTemp") as "MaxTemp", "RainTomorrow" FROM weather - WHERE "RainToday" = 'no' + WHERE "RainToday" = 'No' GROUP BY "RainTomorrow" ) SELECT @@ -63,13 +74,14 @@ mod tests { "#; let s_ctx = SessionContext::default(); - let (s_physical, d_physical) = execute(&s_ctx, &d_ctx, query).await?; + let (s_physical, mut d_physical) = execute(&s_ctx, &d_ctx, query).await?; + d_physical = rewrite_distributed_plan_with_metrics(d_physical.clone(), format)?; println!("{}", display_plan_ascii(s_physical.as_ref(), true)); println!("{}", display_plan_ascii(d_physical.as_ref(), true)); for data_source_index in 0..2 { assert_metrics_equal::( - ["output_rows", "bytes_scanned"], + ["output_rows", "output_bytes"], &s_physical, &d_physical, data_source_index, @@ -79,8 +91,12 @@ mod tests { Ok(()) } + #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")] + #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")] #[tokio::test] - async fn test_metrics_collection_in_union() -> Result<(), Box> { + async fn test_metrics_collection_in_union( + format: DistributedMetricsFormat, + ) -> Result<(), Box> { let (d_ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; let query = r#" @@ 
-96,19 +112,20 @@ mod tests { "#; let s_ctx = SessionContext::default(); - let (s_physical, d_physical) = execute(&s_ctx, &d_ctx, query).await?; + let (s_physical, mut d_physical) = execute(&s_ctx, &d_ctx, query).await?; + + d_physical = rewrite_distributed_plan_with_metrics(d_physical.clone(), format)?; println!("{}", display_plan_ascii(s_physical.as_ref(), true)); println!("{}", display_plan_ascii(d_physical.as_ref(), true)); for data_source_index in 0..5 { assert_metrics_equal::( - ["output_rows", "bytes_scanned"], + ["output_rows", "output_bytes"], &s_physical, &d_physical, data_source_index, ); } - Ok(()) } @@ -149,7 +166,6 @@ mod tests { execute_stream(d_physical.clone(), d_ctx.task_ctx())? .try_collect::>() .await?; - let d_physical = rewrite_distributed_plan_with_metrics(d_physical.clone())?; Ok((s_physical, d_physical)) } @@ -174,23 +190,15 @@ mod tests { .unwrap(); let metrics = metrics .unwrap_or_else(|| panic!("Could not find metrics for plan {}", T::static_name())); - let is_distributed = plan.as_any().is::(); - metrics + let summed = metrics .iter() - .find(|v| v.value().name() == metric_name) - .unwrap_or_else(|| { - panic!( - "{} Could not find metric '{metric_name}' in {}. Available metrics are: {:?}", - if is_distributed { - "(distributed)" - } else { - "(single node)" - }, - T::static_name(), - metrics.iter().map(|v| v.value().name()).collect_vec() - ) - }) - .value() - .as_usize() + .filter(|v| v.value().name().starts_with(metric_name)) + .map(|v| v.value().as_usize()) + .sum(); + assert!( + summed > 0, + "Sum of metric values is 0. Either the metric {metric_name} is not present or the test is too trivial" + ); + summed } }