diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index e2be2a0e240a7..98317b3ff487f 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -21,6 +21,7 @@ //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. use std::any::Any; +use std::collections::BTreeMap; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -63,13 +64,15 @@ impl ProjectionExec { let fields: Result> = expr .iter() - .map(|(e, name)| match input_schema.field_with_name(name) { - Ok(f) => Ok(f.clone()), - Err(_) => { - let dt = e.data_type(&input_schema)?; - let nullable = e.nullable(&input_schema)?; - Ok(Field::new(name, dt, nullable)) - } + .map(|(e, name)| { + let mut field = Field::new( + name, + e.data_type(&input_schema)?, + e.nullable(&input_schema)?, + ); + field.set_metadata(get_field_metadata(e, &input_schema)); + + Ok(field) }) .collect(); @@ -179,6 +182,24 @@ impl ExecutionPlan for ProjectionExec { } } +/// If e is a direct column reference, returns the field level +/// metadata for that field, if any. Otherwise returns None +fn get_field_metadata( + e: &Arc, + input_schema: &Schema, +) -> Option> { + let name = if let Some(column) = e.as_any().downcast_ref::() { + column.name() + } else { + return None; + }; + + input_schema + .field_with_name(name) + .ok() + .and_then(|f| f.metadata().as_ref().cloned()) +} + fn stats_projection( stats: Statistics, exprs: impl Iterator>, diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 945bb7ebc2eb8..0b1abbe2180c7 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -891,6 +891,29 @@ async fn projection_same_fields() -> Result<()> { Ok(()) } +#[tokio::test] +async fn projection_type_alias() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + // Query that aliases one column to the name of a different column + // that also has a different type (c1 == float32, c3 == boolean) + let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c3 |", + "+---------+", + "| 0.00001 |", + "| 0.00002 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_float64() -> Result<()> { let mut ctx = ExecutionContext::new();