diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index e8898c1557a8a..dec9a9136a5d8 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -29,7 +29,7 @@ use crate::physical_plan::{
 };
 pub use arrow::compute::SortOptions;
 use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
-use arrow::datatypes::{Field, Schema, SchemaRef};
+use arrow::datatypes::SchemaRef;
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 use arrow::{array::ArrayRef, error::ArrowError};
@@ -201,15 +201,6 @@ fn sort_batch(
         None,
     )?;
 
-    let schema = Arc::new(Schema::new(
-        schema
-            .fields()
-            .iter()
-            .zip(batch.columns().iter().map(|col| col.data_type()))
-            .map(|(field, ty)| Field::new(field.name(), ty.clone(), field.is_nullable()))
-            .collect::<Vec<_>>(),
-    ));
-
     // reorder all rows based on sorted indices
     RecordBatch::try_new(
         schema,
@@ -318,6 +309,8 @@ impl RecordBatchStream for SortStream {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::{BTreeMap, HashMap};
+
     use super::*;
     use crate::datasource::object_store::local::LocalFileSystem;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
@@ -398,6 +391,57 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_sort_metadata() -> Result<()> {
+        let field_metadata: BTreeMap<String, String> =
+            vec![("foo".to_string(), "bar".to_string())]
+                .into_iter()
+                .collect();
+        let schema_metadata: HashMap<String, String> =
+            vec![("baz".to_string(), "barf".to_string())]
+                .into_iter()
+                .collect();
+
+        let mut field = Field::new("field_name", DataType::UInt64, true);
+        field.set_metadata(Some(field_metadata.clone()));
+        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
+        let schema = Arc::new(schema);
+
+        let data: ArrayRef =
+            Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::<UInt64Array>());
+
+        let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap();
+        let input =
+            Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap());
+
+        let sort_exec = Arc::new(SortExec::try_new(
+            vec![PhysicalSortExpr {
+                expr: col("field_name", &schema)?,
+                options: SortOptions::default(),
+            }],
+            input,
+        )?);
+
+        let result: Vec<RecordBatch> = collect(sort_exec).await?;
+
+        let expected_data: ArrayRef =
+            Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::<UInt64Array>());
+        let expected_batch =
+            RecordBatch::try_new(schema.clone(), vec![expected_data]).unwrap();
+
+        // Data is correct
+        assert_eq!(&vec![expected_batch], &result);
+
+        // explicitly ensure the metadata is present
+        assert_eq!(
+            result[0].schema().fields()[0].metadata(),
+            &Some(field_metadata)
+        );
+        assert_eq!(result[0].schema().metadata(), &schema_metadata);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_lex_sort_by_float() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
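
For context on why deleting the schema-rebuild block in `sort_batch` is enough: the removed code constructed fresh fields with `Field::new` and wrapped them in `Schema::new`, both of which start with no metadata, while the input `schema` already describes the sorted columns and carries the original field- and schema-level metadata. The sketch below is not part of the patch; it is a minimal standalone illustration of that behaviour, using only the arrow `Field`/`Schema` APIs already exercised by `test_sort_metadata` and assuming the arrow version targeted by this diff (where `Field::metadata()` returns `&Option<BTreeMap<String, String>>`).

```rust
use std::collections::{BTreeMap, HashMap};

use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // A field and schema carrying metadata, set up the same way as in test_sort_metadata.
    let field_metadata: BTreeMap<String, String> =
        vec![("foo".to_string(), "bar".to_string())].into_iter().collect();
    let schema_metadata: HashMap<String, String> =
        vec![("baz".to_string(), "barf".to_string())].into_iter().collect();

    let mut field = Field::new("field_name", DataType::UInt64, true);
    field.set_metadata(Some(field_metadata));
    let original = Schema::new_with_metadata(vec![field], schema_metadata);

    // Rebuilding the schema the way the removed block did: Field::new and Schema::new
    // know nothing about the original metadata, so both the field-level and the
    // schema-level entries are dropped.
    let rebuilt = Schema::new(
        original
            .fields()
            .iter()
            .map(|f| Field::new(f.name(), f.data_type().clone(), f.is_nullable()))
            .collect::<Vec<_>>(),
    );

    // The original schema still carries metadata; the rebuilt one has lost it.
    assert!(original.fields()[0].metadata().is_some());
    assert!(!original.metadata().is_empty());
    assert!(rebuilt.fields()[0].metadata().is_none());
    assert!(rebuilt.metadata().is_empty());
}
```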