diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index e798751b3353..f3cdd048968b 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::DataType; +use datafusion_common::logical_type::LogicalType; use std::sync::Arc; use datafusion::error::Result; @@ -48,7 +48,7 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? + .filter(out_ref_col(LogicalType::Utf8, "t1.c1").eq(col("t2.c1")))? .aggregate(vec![], vec![avg(col("t2.c2"))])? .select(vec![avg(col("t2.c2"))])? .into_unoptimized_plan(), @@ -91,7 +91,7 @@ async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { .filter(exists(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? + .filter(out_ref_col(LogicalType::Utf8, "t1.c1").eq(col("t2.c1")))? .select(vec![col("t2.c2")])? .into_unoptimized_plan(), )))? diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 591f6ac3de95..e2b54fb68daf 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -28,6 +28,9 @@ use datafusion::functions_aggregate::first_last::first_value_udaf; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; use datafusion::prelude::*; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::schema::LogicalSchema; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; @@ -156,7 +159,7 @@ fn simplify_demo() -> Result<()> { // However, DataFusion's simplification logic can do this for you // you need to tell DataFusion the type of column "ts": - let schema = Schema::new(vec![make_ts_field("ts")]).to_dfschema_ref()?; + let schema = LogicalSchema::from(Schema::new(vec![make_ts_field("ts")])).to_dfschema_ref()?; // And then build a simplifier // the ExecutionProps carries information needed to simplify @@ -177,10 +180,10 @@ fn simplify_demo() -> Result<()> { ); // here are some other examples of what DataFusion is capable of - let schema = Schema::new(vec![ + let schema = LogicalSchema::from(Schema::new(vec![ make_field("i", DataType::Int64), make_field("b", DataType::Boolean), - ]) + ])) .to_dfschema_ref()?; let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); @@ -211,7 +214,7 @@ fn simplify_demo() -> Result<()> { // String --> Date simplification // `cast('2020-09-01' as date)` --> 18506 assert_eq!( - simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, + simplifier.simplify(lit("2020-09-01").cast_to(&LogicalType::Date32, &schema)?)?, lit(ScalarValue::Date32(Some(18506))) ); @@ -258,7 +261,7 @@ fn range_analysis_demo() -> Result<()> { let analysis_result = analyze( &physical_expr, AnalysisContext::new(boundaries), - df_schema.as_ref(), + &df_schema.into(), )?; // The result of the analysis is a range, encoded as an `Interval`, for @@ -293,14 +296,14 @@ fn expression_type_demo() -> Result<()> { // a
schema. In this case we create a schema where the column `c` is of // type Utf8 (a String / VARCHAR) let schema = DFSchema::from_unqualifed_fields( - vec![Field::new("c", DataType::Utf8, true)].into(), + vec![LogicalField::new("c", LogicalType::Utf8, true)].into(), HashMap::new(), )?; assert_eq!("Utf8", format!("{}", expr.get_type(&schema).unwrap())); // Using a schema where the column `c` is of type Int32 let schema = DFSchema::from_unqualifed_fields( - vec![Field::new("c", DataType::Int32, true)].into(), + vec![LogicalField::new("c", LogicalType::Int32, true)].into(), HashMap::new(), )?; assert_eq!("Int32", format!("{}", expr.get_type(&schema).unwrap())); @@ -310,8 +313,8 @@ fn expression_type_demo() -> Result<()> { let expr = col("c1") + col("c2"); let schema = DFSchema::from_unqualifed_fields( vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Float32, true), + LogicalField::new("c1", LogicalType::Int32, true), + LogicalField::new("c2", LogicalType::Float32, true), ] .into(), HashMap::new(), diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index f57b3bf60404..bf2a44e7063e 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -22,6 +22,7 @@ use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, }; +use datafusion_common::logical_type::extension::ExtensionType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{exec_err, internal_err, DataFusionError}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -216,13 +217,14 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { .expect("Expression has to be defined!"), return_type: definition .return_type - .expect("Return type has to be defined!"), + .expect("Return type has to be defined!") + .physical_type(), signature: Signature::exact( definition .args .unwrap_or_default() .into_iter() - .map(|a| a.data_type) + .map(|a| a.data_type.physical_type()) .collect(), definition .params diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 06286d5d66ed..ed748288ff3d 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -17,6 +17,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; +use datafusion_common::logical_type::LogicalType; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ @@ -211,7 +212,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { + fn get_variable_type(&self, _variable_names: &[String]) -> Option<LogicalType> { None } diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index e36a4f890644..ac54e87fd617 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -17,8 +17,6 @@ //!
Column -use arrow_schema::{Field, FieldRef}; - use crate::error::_schema_err; use crate::utils::{parse_identifiers_normalized, quote_identifier}; use crate::{DFSchema, DataFusionError, Result, SchemaError, TableReference}; @@ -27,6 +25,7 @@ use std::convert::Infallible; use std::fmt; use std::str::FromStr; use std::sync::Arc; +use crate::logical_type::field::{LogicalField, LogicalFieldRef}; /// A named reference to a qualified field in a schema. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -349,15 +348,15 @@ impl From<String> for Column { } /// Create a column, use qualifier and field name -impl From<(Option<&TableReference>, &Field)> for Column { - fn from((relation, field): (Option<&TableReference>, &Field)) -> Self { +impl From<(Option<&TableReference>, &LogicalField)> for Column { + fn from((relation, field): (Option<&TableReference>, &LogicalField)) -> Self { Self::new(relation.cloned(), field.name()) } } /// Create a column, use qualifier and field name -impl From<(Option<&TableReference>, &FieldRef)> for Column { - fn from((relation, field): (Option<&TableReference>, &FieldRef)) -> Self { +impl From<(Option<&TableReference>, &LogicalFieldRef)> for Column { + fn from((relation, field): (Option<&TableReference>, &LogicalFieldRef)) -> Self { Self::new(relation.cloned(), field.name()) } } @@ -380,7 +379,7 @@ impl fmt::Display for Column { mod tests { use super::*; use arrow::datatypes::DataType; - use arrow_schema::SchemaBuilder; + use arrow_schema::{Field, SchemaBuilder}; fn create_qualified_schema(qualifier: &str, names: Vec<&str>) -> Result<DFSchema> { let mut schema_builder = SchemaBuilder::new(); @@ -389,7 +388,7 @@ mod tests { .iter() .map(|f| Field::new(*f, DataType::Boolean, true)), ); - let schema = Arc::new(schema_builder.finish()); + let schema = Arc::new(schema_builder.finish().into()); DFSchema::try_from_qualified_schema(qualifier, &schema) } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 0dab13d08731..f407383319ac 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -30,8 +30,12 @@ use crate::{ }; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -use arrow_schema::SchemaBuilder; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use crate::logical_type::extension::ExtensionType; +use crate::logical_type::field::{LogicalField, LogicalFieldRef}; +use crate::logical_type::fields::LogicalFields; +use crate::logical_type::LogicalType; +use crate::logical_type::schema::{LogicalSchema, LogicalSchemaBuilder, LogicalSchemaRef}; /// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc<DFSchema>; @@ -62,7 +66,7 @@ pub type DFSchemaRef = Arc<DFSchema>; /// Field::new("c1", DataType::Int32, false), /// ]); /// -/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); +/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema.into()).unwrap(); /// let column = Column::from_qualified_name("t1.c1"); /// assert!(df_schema.has_column(&column)); /// @@ -107,9 +111,9 @@ pub type DFSchemaRef = Arc<DFSchema>; #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Inner Arrow schema reference. - inner: SchemaRef, + inner: LogicalSchemaRef, /// Optional qualifiers for each column in this schema. In the same order as - /// the `self.inner.fields()` + /// the `self.inner.fields` field_qualifiers: Vec<Option<TableReference>>, /// Stores functional dependencies in the schema.
functional_dependencies: FunctionalDependencies, @@ -119,7 +123,7 @@ impl DFSchema { /// Creates an empty `DFSchema` pub fn empty() -> Self { Self { - inner: Arc::new(Schema::new([])), + inner: Arc::new(LogicalSchema::new([])), field_qualifiers: vec![], functional_dependencies: FunctionalDependencies::empty(), } @@ -128,26 +132,26 @@ impl DFSchema { /// Return a reference to the inner Arrow [`Schema`] /// /// Note this does not have the qualifier information - pub fn as_arrow(&self) -> &Schema { + pub fn as_arrow(&self) -> &LogicalSchema { self.inner.as_ref() } /// Return a reference to the inner Arrow [`SchemaRef`] /// /// Note this does not have the qualifier information - pub fn inner(&self) -> &SchemaRef { + pub fn inner(&self) -> &LogicalSchemaRef { &self.inner } /// Create a `DFSchema` from an Arrow schema where all the fields have a given qualifier pub fn new_with_metadata( - qualified_fields: Vec<(Option<TableReference>, Arc<Field>)>, + qualified_fields: Vec<(Option<TableReference>, Arc<LogicalField>)>, metadata: HashMap<String, String>, ) -> Result<Self> { - let (qualifiers, fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) = + let (qualifiers, fields): (Vec<Option<TableReference>>, Vec<Arc<LogicalField>>) = qualified_fields.into_iter().unzip(); - let schema = Arc::new(Schema::new_with_metadata(fields, metadata)); + let schema = Arc::new(LogicalSchema::new_with_metadata(fields, metadata)); let dfschema = Self { inner: schema, @@ -160,11 +164,11 @@ impl DFSchema { /// Create a new `DFSchema` from a list of Arrow [Field]s pub fn from_unqualifed_fields( - fields: Fields, + fields: LogicalFields, metadata: HashMap<String, String>, ) -> Result<Self> { let field_count = fields.len(); - let schema = Arc::new(Schema::new_with_metadata(fields, metadata)); + let schema = Arc::new(LogicalSchema::new_with_metadata(fields, metadata)); let dfschema = Self { inner: schema, field_qualifiers: vec![None; field_count], @@ -180,7 +184,7 @@ impl DFSchema { /// `DFSchema::try_from`. pub fn try_from_qualified_schema( qualifier: impl Into<TableReference>, - schema: &Schema, + schema: &LogicalSchema, ) -> Result<Self> { let qualifier = qualifier.into(); let schema = DFSchema { @@ -195,7 +199,7 @@ impl DFSchema { /// Create a `DFSchema` from an Arrow schema where all the fields have a given qualifier pub fn from_field_specific_qualified_schema( qualifiers: Vec<Option<TableReference>>, - schema: &SchemaRef, + schema: &LogicalSchemaRef, ) -> Result<Self> { let dfschema = Self { inner: schema.clone(), @@ -211,7 +215,7 @@ impl DFSchema { let mut qualified_names = BTreeSet::new(); let mut unqualified_names = BTreeSet::new(); - for (field, qualifier) in self.inner.fields().iter().zip(&self.field_qualifiers) { + for (field, qualifier) in self.inner.fields.iter().zip(&self.field_qualifiers) { if let Some(qualifier) = qualifier { qualified_names.insert((qualifier, field.name())); } else if !unqualified_names.insert(field.name()) { @@ -250,8 +254,8 @@ impl DFSchema { /// Create a new schema that contains the fields from this schema followed by the fields /// from the supplied schema. An error will be returned if there are duplicate field names.
pub fn join(&self, schema: &DFSchema) -> Result<Self> { - let mut schema_builder = SchemaBuilder::new(); - schema_builder.extend(self.inner.fields().iter().cloned()); + let mut schema_builder = LogicalSchemaBuilder::new(); + schema_builder.extend(self.inner.fields.iter().cloned()); schema_builder.extend(schema.fields().iter().cloned()); let new_schema = schema_builder.finish(); @@ -278,23 +282,23 @@ impl DFSchema { return; } - let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> = + let self_fields: HashSet<(Option<&TableReference>, &LogicalFieldRef)> = self.iter().collect(); let self_unqualified_names: HashSet<&str> = self .inner .fields .iter() - .map(|field| field.name().as_str()) + .map(|field| field.name()) .collect(); - let mut schema_builder = SchemaBuilder::from(self.inner.fields.clone()); + let mut schema_builder = LogicalSchemaBuilder::from(self.inner.fields.clone()); let mut qualifiers = Vec::new(); for (qualifier, field) in other_schema.iter() { // skip duplicate columns let duplicated_field = match qualifier { Some(q) => self_fields.contains(&(Some(q), field)), // for unqualified columns, check as unqualified name - None => self_unqualified_names.contains(field.name().as_str()), + None => self_unqualified_names.contains(field.name()), }; if !duplicated_field { // self.inner.fields.push(field.clone()); @@ -312,19 +316,19 @@ impl DFSchema { } /// Get a list of fields - pub fn fields(&self) -> &Fields { + pub fn fields(&self) -> &LogicalFields { &self.inner.fields } /// Returns an immutable reference of a specific `Field` instance selected using an /// offset within the internal `fields` vector - pub fn field(&self, i: usize) -> &Field { + pub fn field(&self, i: usize) -> &LogicalField { &self.inner.fields[i] } /// Returns an immutable reference of a specific `Field` instance selected using an /// offset within the internal `fields` vector and its qualifier - pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &Field) { + pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &LogicalField) { (self.field_qualifiers[i].as_ref(), self.field(i)) } @@ -391,7 +395,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<&Field> { + ) -> Result<&LogicalField> { if let Some(qualifier) = qualifier { self.field_with_qualified_name(qualifier, name) } else { @@ -404,7 +408,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &LogicalField)> { if let Some(qualifier) = qualifier { let idx = self .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; @@ -416,7 +420,7 @@ impl DFSchema { } /// Find all fields having the given qualifier - pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&Field> { + pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&LogicalField> { self.iter() .filter(|(q, _)| q.map(|q| q.eq(qualifier)).unwrap_or(false)) .map(|(_, f)| f.as_ref()) .collect() } @@ -435,7 +439,7 @@ impl DFSchema { } /// Find all fields that match the given name - pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&Field> { + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&LogicalField> { self.fields() .iter() .filter(|field| field.name() == name) @@ -447,7 +451,7 @@ impl DFSchema { pub fn qualified_fields_with_unqualified_name( &self, name: &str, - ) -> Vec<(Option<&TableReference>, &Field)> { + ) -> Vec<(Option<&TableReference>, &LogicalField)> { self.iter() .filter(|(_,
field)| field.name() == name) .map(|(qualifier, field)| (qualifier, field.as_ref())) .collect() } @@ -466,7 +470,7 @@ impl DFSchema { pub fn columns(&self) -> Vec<Column> { self.iter() .map(|(qualifier, field)| { - Column::new(qualifier.cloned(), field.name().clone()) + Column::new(qualifier.cloned(), field.name()) }) .collect() } @@ -475,7 +479,7 @@ impl DFSchema { pub fn qualified_field_with_unqualified_name( &self, name: &str, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &LogicalField)> { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), @@ -507,7 +511,7 @@ impl DFSchema { } /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&Field> { + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&LogicalField> { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), @@ -543,7 +547,7 @@ impl DFSchema { &self, qualifier: &TableReference, name: &str, - ) -> Result<&Field> { + ) -> Result<&LogicalField> { let idx = self .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; @@ -552,7 +556,7 @@ impl DFSchema { } /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&Field> { + pub fn field_from_column(&self, column: &Column) -> Result<&LogicalField> { match &column.relation { Some(r) => self.field_with_qualified_name(r, &column.name), None => self.field_with_unqualified_name(&column.name), @@ -563,7 +567,7 @@ impl DFSchema { pub fn qualified_field_from_column( &self, column: &Column, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &LogicalField)> { self.qualified_field_with_name(column.relation.as_ref(), &column.name) } @@ -658,7 +662,7 @@ impl DFSchema { self_fields.zip(other_fields).all(|((q1, f1), (q2, f2))| { q1 == q2 && f1.name() == f2.name() - && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) + && Self::datatype_is_semantically_equal(&f1.data_type().physical_type(), &f2.data_type().physical_type()) }) } @@ -666,40 +670,8 @@ impl DFSchema { /// than datatype_is_semantically_equal in that a Dictionary<K, V> type is logically /// equal to a plain V type, but not semantically equal. Dictionary<K1, V1> is also /// logically equal to Dictionary<K2, V1>.
- pub fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { - // check nested fields - match (dt1, dt2) { - (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref() == v2.as_ref() - } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, - (DataType::List(f1), DataType::List(f2)) - | (DataType::LargeList(f1), DataType::LargeList(f2)) - | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) - | (DataType::Map(f1, _), DataType::Map(f2, _)) => { - Self::field_is_logically_equal(f1, f2) - } - (DataType::Struct(fields1), DataType::Struct(fields2)) => { - let iter1 = fields1.iter(); - let iter2 = fields2.iter(); - fields1.len() == fields2.len() && - // all fields have to be the same - iter1 - .zip(iter2) - .all(|(f1, f2)| Self::field_is_logically_equal(f1, f2)) - } - (DataType::Union(fields1, _), DataType::Union(fields2, _)) => { - let iter1 = fields1.iter(); - let iter2 = fields2.iter(); - fields1.len() == fields2.len() && - // all fields have to be the same - iter1 - .zip(iter2) - .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_logically_equal(f1, f2)) - } - _ => dt1 == dt2, - } + pub fn datatype_is_logically_equal(dt1: &LogicalType, dt2: &LogicalType) -> bool { + dt1 == dt2 } /// Returns true if two [`DataType`]s are semantically equal (same @@ -749,11 +721,6 @@ impl DFSchema { } } - fn field_is_logically_equal(f1: &Field, f2: &Field) -> bool { - f1.name() == f2.name() - && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) - } - fn field_is_semantically_equal(f1: &Field, f2: &Field) -> bool { f1.name() == f2.name() && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) @@ -796,10 +763,10 @@ impl DFSchema { } /// Iterate over the qualifiers and fields in the DFSchema - pub fn iter(&self) -> impl Iterator<Item = (Option<&TableReference>, &FieldRef)> { + pub fn iter(&self) -> impl Iterator<Item = (Option<&TableReference>, &LogicalFieldRef)> { self.field_qualifiers .iter() - .zip(self.inner.fields().iter()) + .zip(self.inner.fields.iter()) .map(|(qualifier, field)| (qualifier.as_ref(), field)) } } @@ -807,7 +774,7 @@ impl DFSchema { impl From<DFSchema> for Schema { /// Convert DFSchema into a Schema fn from(df_schema: DFSchema) -> Self { - let fields: Fields = df_schema.inner.fields.clone(); + let fields: Fields = df_schema.inner.fields.clone().into(); Schema::new_with_metadata(fields, df_schema.inner.metadata.clone()) } } @@ -815,23 +782,29 @@ impl From<&DFSchema> for Schema { /// Convert DFSchema reference into a Schema fn from(df_schema: &DFSchema) -> Self { - let fields: Fields = df_schema.inner.fields.clone(); + let fields: Fields = df_schema.inner.fields.clone().into(); Schema::new_with_metadata(fields, df_schema.inner.metadata.clone()) } } -/// Allow DFSchema to be converted into an Arrow `&Schema` -impl AsRef<Schema> for DFSchema { - fn as_ref(&self) -> &Schema { - self.as_arrow() +/// Create a `DFSchema` from an Arrow schema +impl TryFrom<LogicalSchema> for DFSchema { + type Error = DataFusionError; + fn try_from(schema: LogicalSchema) -> Result<Self, Self::Error> { + Self::try_from(Arc::new(schema)) } } -/// Allow DFSchema to be converted into an Arrow `&SchemaRef` (to clone, for -/// example) -impl AsRef<SchemaRef> for DFSchema { - fn as_ref(&self) -> &SchemaRef { - self.inner() +impl TryFrom<LogicalSchemaRef> for DFSchema { + type Error = DataFusionError; + fn try_from(schema: LogicalSchemaRef) -> Result<Self, Self::Error> { + let field_count = schema.fields.len(); + let dfschema = Self { + inner: schema, + field_qualifiers: vec![None;
field_count], + functional_dependencies: FunctionalDependencies::empty(), + }; + Ok(dfschema) } } @@ -839,20 +812,14 @@ impl AsRef<SchemaRef> for DFSchema { impl TryFrom<Schema> for DFSchema { type Error = DataFusionError; fn try_from(schema: Schema) -> Result<Self, Self::Error> { - Self::try_from(Arc::new(schema)) + Self::try_from(LogicalSchema::from(schema)) } } impl TryFrom<SchemaRef> for DFSchema { type Error = DataFusionError; fn try_from(schema: SchemaRef) -> Result<Self, Self::Error> { - let field_count = schema.fields.len(); - let dfschema = Self { - inner: schema, - field_qualifiers: vec![None; field_count], - functional_dependencies: FunctionalDependencies::empty(), - }; - Ok(dfschema) + Self::try_from(schema.as_ref().clone()) } } @@ -884,22 +851,22 @@ where } } -impl ToDFSchema for Schema { +impl ToDFSchema for LogicalSchema { fn to_dfschema(self) -> Result<DFSchema> { DFSchema::try_from(self) } } -impl ToDFSchema for SchemaRef { +impl ToDFSchema for LogicalSchemaRef { fn to_dfschema(self) -> Result<DFSchema> { DFSchema::try_from(self) } } -impl ToDFSchema for Vec<Field> { +impl ToDFSchema for Vec<LogicalField> { fn to_dfschema(self) -> Result<DFSchema> { let field_count = self.len(); - let schema = Schema { + let schema = LogicalSchema { fields: self.into(), metadata: HashMap::new(), }; @@ -936,13 +903,13 @@ pub trait ExprSchema: std::fmt::Debug { fn nullable(&self, col: &Column) -> Result<bool>; /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&LogicalType>; /// Returns the column's optional metadata. fn metadata(&self, col: &Column) -> Result<&HashMap<String, String>>; /// Return the column's datatype and nullability - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; + fn data_type_and_nullable(&self, col: &Column) -> Result<(&LogicalType, bool)>; } // Implement `ExprSchema` for `Arc<DFSchema>` @@ -951,7 +918,7 @@ impl<P: AsRef<DFSchema> + std::fmt::Debug> ExprSchema for P { self.as_ref().nullable(col) } - fn data_type(&self, col: &Column) -> Result<&DataType> { + fn data_type(&self, col: &Column) -> Result<&LogicalType> { self.as_ref().data_type(col) } @@ -959,7 +926,7 @@ impl<P: AsRef<DFSchema> + std::fmt::Debug> ExprSchema for P { ExprSchema::metadata(self.as_ref(), col) } - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + fn data_type_and_nullable(&self, col: &Column) -> Result<(&LogicalType, bool)> { self.as_ref().data_type_and_nullable(col) } } @@ -969,7 +936,7 @@ impl ExprSchema for DFSchema { Ok(self.field_from_column(col)?.is_nullable()) } - fn data_type(&self, col: &Column) -> Result<&DataType> { + fn data_type(&self, col: &Column) -> Result<&LogicalType> { Ok(self.field_from_column(col)?.data_type()) } @@ -977,7 +944,7 @@ impl ExprSchema for DFSchema { Ok(self.field_from_column(col)?.metadata()) } - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + fn data_type_and_nullable(&self, col: &Column) -> Result<(&LogicalType, bool)> { let field = self.field_from_column(col)?; Ok((field.data_type(), field.is_nullable())) } @@ -1028,8 +995,8 @@ impl SchemaExt for Schema { .all(|(f1, f2)| { f1.name() == f2.name() && DFSchema::datatype_is_logically_equal( - f1.data_type(), - f2.data_type(), + &f1.data_type().into(), + &f2.data_type().into(), ) }) } @@ -1069,7 +1036,7 @@ mod tests { &Schema::new(vec![ Field::new("CapitalColumn", DataType::Boolean, true), Field::new("field.with.period", DataType::Boolean, true), - ]), + ]).into(), )?; // lookup with unqualified name "t1.c0" @@ -1099,9 +1066,9 @@ mod tests { fn test_from_field_specific_qualified_schema() -> Result<()> { let schema =
DFSchema::from_field_specific_qualified_schema( vec![Some("t1".into()), None], - &Arc::new(Schema::new(vec![ - Field::new("c0", DataType::Boolean, true), - Field::new("c1", DataType::Boolean, true), + &Arc::new(LogicalSchema::new(vec![ + LogicalField::new("c0", LogicalType::Boolean, true), + LogicalField::new("c1", LogicalType::Boolean, true), ])), )?; assert_eq!("fields:[t1.c0, c1], metadata:{}", schema.to_string()); @@ -1114,9 +1081,9 @@ mod tests { vec![ ( Some("t0".into()), - Arc::new(Field::new("c0", DataType::Boolean, true)), + Arc::new(Field::new("c0", DataType::Boolean, true).into()), ), - (None, Arc::new(Field::new("c1", DataType::Boolean, true))), + (None, Arc::new(Field::new("c1", DataType::Boolean, true).into())), ], HashMap::new(), )?; @@ -1260,41 +1227,42 @@ mod tests { vec![Field::new("c0", DataType::Int64, true)], metadata.clone(), ); - let arrow_schema_ref = Arc::new(arrow_schema.clone()); + let logical_schema = LogicalSchema::from(arrow_schema); + let logical_schema_ref = Arc::new(logical_schema.clone()); let df_schema = DFSchema { - inner: arrow_schema_ref.clone(), - field_qualifiers: vec![None; arrow_schema_ref.fields.len()], + inner: logical_schema_ref.clone(), + field_qualifiers: vec![None; logical_schema_ref.fields.len()], functional_dependencies: FunctionalDependencies::empty(), }; let df_schema_ref = Arc::new(df_schema.clone()); { - let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); + let logical_schema = logical_schema.clone(); + let logical_schema_ref = logical_schema_ref.clone(); - assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); - assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); + assert_eq!(df_schema, logical_schema.to_dfschema().unwrap()); + assert_eq!(df_schema, logical_schema_ref.to_dfschema().unwrap()); } { - let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); + let logical_schema = logical_schema.clone(); + let logical_schema_ref = logical_schema_ref.clone(); - assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); - assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, logical_schema.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, logical_schema_ref.to_dfschema_ref().unwrap()); } // Now, consume the refs - assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); - assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, logical_schema.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, logical_schema_ref.to_dfschema_ref().unwrap()); } - fn test_schema_1() -> Schema { + fn test_schema_1() -> LogicalSchema { Schema::new(vec![ Field::new("c0", DataType::Boolean, true), Field::new("c1", DataType::Boolean, true), - ]) + ]).into() } #[test] fn test_dfschema_to_schema_convertion() { @@ -1306,7 +1274,7 @@ mod tests { b_metadata.insert("key".to_string(), "value".to_string()); let b_field = Field::new("b", DataType::Int64, false).with_metadata(b_metadata); - let schema = Arc::new(Schema::new(vec![a_field, b_field])); + let schema = LogicalSchemaRef::new(Schema::new(vec![a_field, b_field]).into()); let df_schema = DFSchema { inner: schema.clone(), @@ -1350,10 +1318,10 @@ mod tests { Ok(()) } - fn test_schema_2() -> Schema { - Schema::new(vec![ - Field::new("c100", DataType::Boolean, true), - Field::new("c101", DataType::Boolean, true), + fn test_schema_2() -> LogicalSchema { + LogicalSchema::new(vec![ + LogicalField::new("c100", 
LogicalType::Boolean, true), + LogicalField::new("c101", LogicalType::Boolean, true), ]) } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index c275152642f0..7585f1a2c12a 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -42,6 +42,7 @@ pub mod stats; pub mod test_util; pub mod tree_node; pub mod utils; +pub mod logical_type; /// Reexport arrow crate pub use arrow; diff --git a/datafusion/common/src/logical_type/extension.rs b/datafusion/common/src/logical_type/extension.rs new file mode 100644 index 000000000000..0332de49f19e --- /dev/null +++ b/datafusion/common/src/logical_type/extension.rs @@ -0,0 +1,289 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_schema::{DataType, FieldRef, IntervalUnit, TimeUnit}; + +use crate::logical_type::type_signature::TypeSignature; +use crate::logical_type::LogicalType; + +pub type ExtensionTypeRef = Arc<dyn ExtensionType + Send + Sync>; + +pub trait ExtensionType: std::fmt::Debug { + fn display_name(&self) -> &str; + fn type_signature(&self) -> TypeSignature; + fn physical_type(&self) -> DataType; + + fn is_comparable(&self) -> bool; + fn is_orderable(&self) -> bool; + fn is_numeric(&self) -> bool; + fn is_floating(&self) -> bool; +} + +impl ExtensionType for LogicalType { + fn display_name(&self) -> &str { + use crate::logical_type::LogicalType::*; + match self { + Null => "Null", + Boolean => "Boolean", + Int8 => "Int8", + Int16 => "Int16", + Int32 => "Int32", + Int64 => "Int64", + UInt8 => "UInt8", + UInt16 => "UInt16", + UInt32 => "UInt32", + UInt64 => "UInt64", + Float16 => "Float16", + Float32 => "Float32", + Float64 => "Float64", + Date32 => "Date32", + Date64 => "Date64", + Time32(_) => "Time32", + Time64(_) => "Time64", + Timestamp(_, _) => "Timestamp", + Duration(_) => "Duration", + Interval(_) => "Interval", + Binary => "Binary", + FixedSizeBinary(_) => "FixedSizeBinary", + LargeBinary => "LargeBinary", + Utf8 => "Utf8", + LargeUtf8 => "LargeUtf8", + List(_) => "List", + FixedSizeList(_, _) => "FixedSizeList", + LargeList(_) => "LargeList", + Struct(_) => "Struct", + Map(_, _) => "Map", + Decimal128(_, _) => "Decimal128", + Decimal256(_, _) => "Decimal256", + Extension(ext) => ext.display_name(), + } + } + + fn type_signature(&self) -> TypeSignature { + use crate::logical_type::LogicalType::*; + fn time_unit_to_param(tu: &TimeUnit) -> &'static str { + match tu { + TimeUnit::Second => "second", + TimeUnit::Millisecond => "millisecond", + TimeUnit::Microsecond => "microsecond", + TimeUnit::Nanosecond => "nanosecond", + } + } + + match self { + Boolean => TypeSignature::new("boolean"), + Int32 => TypeSignature::new("int32"), + Int64 => TypeSignature::new("int64"), + UInt64 => TypeSignature::new("uint64"), + Float32 =>
TypeSignature::new("float32"), + Float64 => TypeSignature::new("float64"), + Timestamp(tu, zone) => { + let params = if let Some(zone) = zone { + vec![time_unit_to_param(tu).into(), zone.as_ref().into()] + } else { + vec![time_unit_to_param(tu).into()] + }; + + TypeSignature::new_with_params("timestamp", params) + } + Binary => TypeSignature::new("binary"), + Utf8 => TypeSignature::new("string"), + Struct(fields) => { + let params = fields.iter().map(|f| f.name().into()).collect(); + TypeSignature::new_with_params("struct", params) + } + Extension(ext) => ext.type_signature(), + Null => TypeSignature::new("null"), + Int8 => TypeSignature::new("int8"), + Int16 => TypeSignature::new("int16"), + UInt8 => TypeSignature::new("uint8"), + UInt16 => TypeSignature::new("uint16"), + UInt32 => TypeSignature::new("uint32"), + Float16 => TypeSignature::new("float16"), + Date32 => TypeSignature::new("date_32"), + Date64 => TypeSignature::new("date_64"), + Time32(tu) => TypeSignature::new_with_params( + "time_32", + vec![time_unit_to_param(tu).into()], + ), + Time64(tu) => TypeSignature::new_with_params( + "time_64", + vec![time_unit_to_param(tu).into()], + ), + Duration(tu) => TypeSignature::new_with_params( + "duration", + vec![time_unit_to_param(tu).into()], + ), + Interval(iu) => { + let iu = match iu { + IntervalUnit::YearMonth => "year_month", + IntervalUnit::DayTime => "day_time", + IntervalUnit::MonthDayNano => "month_day_nano", + }; + TypeSignature::new_with_params("interval", vec![iu.into()]) + } + FixedSizeBinary(size) => TypeSignature::new_with_params( + "fixed_size_binary", + vec![size.to_string().into()], + ), + LargeBinary => TypeSignature::new("large_binary"), + LargeUtf8 => TypeSignature::new("large_utf_8"), + List(f) => TypeSignature::new_with_params( + "list", + vec![f.data_type().display_name().into()], + ), + FixedSizeList(f, size) => TypeSignature::new_with_params( + "fixed_size_list", + vec![f.data_type().display_name().into(), size.to_string().into()], + ), + LargeList(f) => TypeSignature::new_with_params( + "large_list", + vec![f.data_type().display_name().into()], + ), + Map(f, b) => TypeSignature::new_with_params( + "map", + vec![f.data_type().display_name().into(), b.to_string().into()], + ), + Decimal128(a, b) => TypeSignature::new_with_params( + "decimal_128", + vec![a.to_string().into(), b.to_string().into()], + ), + Decimal256(a, b) => TypeSignature::new_with_params( + "decimal_256", + vec![a.to_string().into(), b.to_string().into()], + ), + } + } + + fn physical_type(&self) -> DataType { + use crate::logical_type::LogicalType::*; + match self { + Boolean => DataType::Boolean, + Int32 => DataType::Int32, + Int64 => DataType::Int64, + UInt64 => DataType::UInt64, + Float32 => DataType::Float32, + Float64 => DataType::Float64, + Timestamp(tu, zone) => DataType::Timestamp(tu.clone(), zone.clone()), + Binary => DataType::Binary, + Utf8 => DataType::Utf8, + Struct(fields) => { + let fields = fields + .iter() + .map(|f| FieldRef::new(f.as_ref().clone().into())) + .collect::>(); + DataType::Struct(fields.into()) + } + Extension(ext) => ext.physical_type(), + Null => DataType::Null, + Int8 => DataType::Int8, + Int16 => DataType::Int16, + UInt8 => DataType::UInt8, + UInt16 => DataType::UInt16, + UInt32 => DataType::UInt32, + Float16 => DataType::Float16, + Date32 => DataType::Date32, + Date64 => DataType::Date64, + Time32(tu) => DataType::Time32(tu.to_owned()), + Time64(tu) => DataType::Time64(tu.to_owned()), + Duration(tu) => DataType::Duration(tu.to_owned()), + Interval(iu) => 
DataType::Interval(iu.to_owned()), + FixedSizeBinary(size) => DataType::FixedSizeBinary(size.to_owned()), + LargeBinary => DataType::LargeBinary, + LargeUtf8 => DataType::LargeUtf8, + List(f) => DataType::List(FieldRef::new(f.as_ref().clone().into())), + FixedSizeList(f, size) => DataType::FixedSizeList(FieldRef::new(f.as_ref().clone().into()), size.to_owned()), + LargeList(f) => DataType::LargeList(FieldRef::new(f.as_ref().clone().into())), + Map(f, b) => DataType::Map(FieldRef::new(f.as_ref().clone().into()), b.to_owned()), + Decimal128(a, b) => DataType::Decimal128(a.to_owned(), b.to_owned()), + Decimal256(a, b) => DataType::Decimal256(a.to_owned(), b.to_owned()), + } + } + + fn is_comparable(&self) -> bool { + use crate::logical_type::LogicalType::*; + match self { + Null + | Boolean + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + | Duration(_) + | Interval(_) + | Binary + | FixedSizeBinary(_) + | LargeBinary + | Utf8 + | LargeUtf8 + | Decimal128(_, _) + | Decimal256(_, _) => true, + Extension(ext) => ext.is_comparable(), + _ => false, + } + } + + fn is_orderable(&self) -> bool { + todo!() + } + + #[inline] + fn is_numeric(&self) -> bool { + use crate::logical_type::LogicalType::*; + match self { + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) => true, + Extension(t) => t.is_numeric(), + _ => false, + } + } + + #[inline] + fn is_floating(&self) -> bool { + use crate::logical_type::LogicalType::*; + match self { + Float16 | Float32 | Float64 => true, + Extension(t) => t.is_floating(), + _ => false, + } + } +} diff --git a/datafusion/common/src/logical_type/field.rs b/datafusion/common/src/logical_type/field.rs new file mode 100644 index 000000000000..3cd6f73fe374 --- /dev/null +++ b/datafusion/common/src/logical_type/field.rs @@ -0,0 +1,163 @@ +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; +use arrow_schema::{DataType, Field}; +use crate::logical_type::extension::ExtensionType; +use crate::logical_type::fields::LogicalFields; +use crate::logical_type::LogicalType; +use crate::logical_type::type_signature::TypeSignature; + +pub type LogicalFieldRef = Arc<LogicalField>; + +#[derive(Debug, Clone)] +pub struct LogicalField { + name: String, + data_type: LogicalType, + nullable: bool, + metadata: HashMap<String, String>, +} + +impl From<&Field> for LogicalField { + fn from(value: &Field) -> Self { + Self { + name: value.name().clone(), + data_type: value.data_type().clone().into(), + nullable: value.is_nullable(), + metadata: value.metadata().clone() + } + } +} + +impl From<Field> for LogicalField { + fn from(value: Field) -> Self { + Self::from(&value) + } +} + +impl Into<Field> for LogicalField { + fn into(self) -> Field { + Field::new(self.name, self.data_type.physical_type(), self.nullable).with_metadata(self.metadata) + } +} + +impl PartialEq for LogicalField { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.data_type == other.data_type + && self.nullable == other.nullable + && self.metadata == other.metadata + } +} + +impl Eq for LogicalField {} + +impl Hash for LogicalField { + fn hash<H: Hasher>(&self, state: &mut H) { + self.name.hash(state); + self.data_type.hash(state); + self.nullable.hash(state); + + // ensure deterministic key order + let mut keys: Vec<&String> = self.metadata.keys().collect(); + keys.sort(); + for k in keys { + k.hash(state); + self.metadata.get(k).expect("key valid").hash(state); + } + } +} + +impl ExtensionType for LogicalField { + fn display_name(&self) -> &str { + &self.name + } + + fn type_signature(&self) -> TypeSignature { + TypeSignature::new(self.name()) + } + + fn physical_type(&self) -> DataType { + self.data_type.physical_type() + } + + fn is_comparable(&self) -> bool { + self.data_type.is_comparable() + } + + fn is_orderable(&self) -> bool { + self.data_type.is_orderable() + } + + fn is_numeric(&self) -> bool { + self.data_type.is_numeric() + } + + fn is_floating(&self) -> bool { + self.data_type.is_floating() + } +} + +impl LogicalField { + pub fn new(name: impl Into<String>, data_type: LogicalType, nullable: bool) -> Self { + LogicalField { + name: name.into(), + data_type, + nullable, + metadata: HashMap::default(), + } + } + + pub fn new_list_field(data_type: LogicalType, nullable: bool) -> Self { + Self::new("item", data_type, nullable) + } + + pub fn new_struct(name: impl Into<String>, fields: impl Into<LogicalFields>, nullable: bool) -> Self { + Self::new(name, LogicalType::Struct(fields.into()), nullable) + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn data_type(&self) -> &LogicalType { + &self.data_type + } + + pub fn is_nullable(&self) -> bool { + self.nullable + } + + pub fn metadata(&self) -> &HashMap<String, String> { + &self.metadata + } + + #[inline] + pub fn with_name(mut self, name: impl Into<String>) -> Self { + self.name = name.into(); + self + } + + #[inline] + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } + + #[inline] + pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self { + self.metadata = metadata; + self + } + + #[inline] + pub fn with_data_type(mut self, data_type: LogicalType) -> Self { + self.data_type = data_type; + self + } +} + +impl std::fmt::Display for LogicalField { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{self:?}") + } +} diff --git a/datafusion/common/src/logical_type/fields.rs b/datafusion/common/src/logical_type/fields.rs new file mode 100644 index 000000000000..456835bc8d9b --- /dev/null +++ b/datafusion/common/src/logical_type/fields.rs @@ -0,0 +1,94 @@ +use std::ops::Deref; +use std::sync::Arc; +use arrow_schema::{Field, Fields}; +use crate::logical_type::field::{LogicalField, LogicalFieldRef}; + +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct LogicalFields(Arc<[LogicalFieldRef]>); + +impl std::fmt::Debug for LogicalFields { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.as_ref().fmt(f) + } +} + +impl From<Fields> for LogicalFields { + fn from(value: Fields) -> Self { + Self(value.into_iter().map(|v| LogicalFieldRef::new(LogicalField::from(v.as_ref()))).collect()) + } +} + +impl Into<Fields> for LogicalFields { + fn into(self) -> Fields { + Fields::from( + self.iter() + .map(|f| f.as_ref().clone().into()) + .collect::<Vec<Field>>() + ) + } +} + +impl Default for LogicalFields { + fn default() -> Self { + Self::empty() + } +} + +impl FromIterator<LogicalField> for LogicalFields { + fn from_iter<T: IntoIterator<Item = LogicalField>>(iter: T) -> Self { + iter.into_iter().map(Arc::new).collect() + } +} + +impl FromIterator<LogicalFieldRef> for LogicalFields { + fn from_iter<T: IntoIterator<Item = LogicalFieldRef>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl From<Vec<LogicalField>> for LogicalFields { + fn from(value: Vec<LogicalField>) -> Self { + value.into_iter().collect() + } +} + +impl From<Vec<LogicalFieldRef>> for LogicalFields { + fn from(value: Vec<LogicalFieldRef>) -> Self { + Self(value.into()) + } +} + +impl From<&[LogicalFieldRef]> for LogicalFields { + fn from(value: &[LogicalFieldRef]) -> Self {
Self(value.into()) + } +} + +impl<const N: usize> From<[LogicalFieldRef; N]> for LogicalFields { + fn from(value: [LogicalFieldRef; N]) -> Self { + Self(Arc::new(value)) + } +} + +impl Deref for LogicalFields { + type Target = [LogicalFieldRef]; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl<'a> IntoIterator for &'a LogicalFields { + type Item = &'a LogicalFieldRef; + type IntoIter = std::slice::Iter<'a, LogicalFieldRef>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl LogicalFields { + pub fn empty() -> Self { + Self(Arc::new([])) + } +} diff --git a/datafusion/common/src/logical_type/mod.rs b/datafusion/common/src/logical_type/mod.rs new file mode 100644 index 000000000000..fd464411a4ef --- /dev/null +++ b/datafusion/common/src/logical_type/mod.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{fmt::Display, sync::Arc}; + +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; + +use crate::logical_type::extension::{ExtensionType, ExtensionTypeRef}; +use crate::logical_type::field::{LogicalField, LogicalFieldRef}; +use crate::logical_type::fields::LogicalFields; + +pub mod type_signature; +pub mod extension; +pub mod registry; +pub mod schema; +pub mod field; +pub mod fields; + +#[derive(Clone, Debug)] +pub enum LogicalType { + Null, + Boolean, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64, + Date32, + Date64, + Time32(TimeUnit), + Time64(TimeUnit), + Timestamp(TimeUnit, Option<Arc<str>>), + Duration(TimeUnit), + Interval(IntervalUnit), + Binary, + FixedSizeBinary(i32), + LargeBinary, + Utf8, + LargeUtf8, + List(LogicalFieldRef), + FixedSizeList(LogicalFieldRef, i32), + LargeList(LogicalFieldRef), + Struct(LogicalFields), + Map(LogicalFieldRef, bool), + Decimal128(u8, i8), + Decimal256(u8, i8), + Extension(ExtensionTypeRef), + // TODO: tbd union +} + +impl PartialEq for LogicalType { + fn eq(&self, other: &Self) -> bool { + self.type_signature() == other.type_signature() + } +} + +impl Eq for LogicalType {} + +impl std::hash::Hash for LogicalType { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.type_signature().hash(state) + } +} + +impl Display for LogicalType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From<&DataType> for LogicalType { + fn from(value: &DataType) -> Self { + value.clone().into() + } +} + +impl From<DataType> for LogicalType { + fn from(value: DataType) -> Self { + match value { + DataType::Null => LogicalType::Null, + DataType::Boolean => LogicalType::Boolean, + DataType::Int8 => LogicalType::Int8, + DataType::Int16 => LogicalType::Int16, + DataType::Int32 => LogicalType::Int32, + DataType::Int64 => LogicalType::Int64, + DataType::UInt8 =>
LogicalType::UInt8, + DataType::UInt16 => LogicalType::UInt16, + DataType::UInt32 => LogicalType::UInt32, + DataType::UInt64 => LogicalType::UInt64, + DataType::Float16 => LogicalType::Float16, + DataType::Float32 => LogicalType::Float32, + DataType::Float64 => LogicalType::Float64, + DataType::Timestamp(tu, z) => LogicalType::Timestamp(tu, z), + DataType::Date32 => LogicalType::Date32, + DataType::Date64 => LogicalType::Date64, + DataType::Time32(tu) => LogicalType::Time32(tu), + DataType::Time64(tu) => LogicalType::Time64(tu), + DataType::Duration(tu) => LogicalType::Duration(tu), + DataType::Interval(iu) => LogicalType::Interval(iu), + DataType::Binary | DataType::BinaryView => LogicalType::Binary, + DataType::FixedSizeBinary(len) => LogicalType::FixedSizeBinary(len), + DataType::LargeBinary => LogicalType::LargeBinary, + DataType::Utf8 | DataType::Utf8View => LogicalType::Utf8, + DataType::LargeUtf8 => LogicalType::LargeUtf8, + DataType::List(f) | DataType::ListView(f) => LogicalType::List(LogicalFieldRef::new(f.as_ref().into())), + DataType::FixedSizeList(f, len) => LogicalType::FixedSizeList(LogicalFieldRef::new(f.as_ref().into()), len), + DataType::LargeList(f) | DataType::LargeListView(f) => LogicalType::LargeList(LogicalFieldRef::new(f.as_ref().into())), + DataType::Struct(fields) => LogicalType::Struct(fields.into()), + DataType::Dictionary(_, dt) => dt.as_ref().into(), + DataType::Decimal128(p, s) => LogicalType::Decimal128(p, s), + DataType::Decimal256(p, s) => LogicalType::Decimal256(p, s), + DataType::Map(f, sorted) => LogicalType::Map(LogicalFieldRef::new(f.as_ref().into()), sorted), + DataType::RunEndEncoded(_, f) => f.data_type().into(), + DataType::Union(_, _) => unimplemented!(), // TODO: tbd union + } + } +} + +impl LogicalType { + + pub fn new_list(data_type: LogicalType, nullable: bool) -> Self { + LogicalType::List(Arc::new(LogicalField::new_list_field(data_type, nullable))) + } + + pub fn new_large_list(data_type: LogicalType, nullable: bool) -> Self { + LogicalType::LargeList(Arc::new(LogicalField::new_list_field(data_type, nullable))) + } + + pub fn new_fixed_size_list(data_type: LogicalType, size: i32, nullable: bool) -> Self { + LogicalType::FixedSizeList(Arc::new(LogicalField::new_list_field(data_type, nullable)), size) + } + + pub fn is_floating(&self) -> bool { + matches!(self, Self::Float16 | Self::Float32 | Self::Float64) + } +} diff --git a/datafusion/common/src/logical_type/registry.rs b/datafusion/common/src/logical_type/registry.rs new file mode 100644 index 000000000000..63d25bb23e91 --- /dev/null +++ b/datafusion/common/src/logical_type/registry.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; + +use crate::error::_plan_datafusion_err; +use crate::logical_type::extension::ExtensionTypeRef; +use crate::logical_type::type_signature::TypeSignature; + +pub trait TypeRegistry { + fn register_data_type( + &mut self, + extension_type: ExtensionTypeRef, + ) -> crate::Result<Option<ExtensionTypeRef>>; + + fn data_type(&self, signature: &TypeSignature) -> crate::Result<ExtensionTypeRef>; +} + + +#[derive(Default, Debug)] +pub struct MemoryTypeRegistry { + types: HashMap<TypeSignature, ExtensionTypeRef>, +} + +impl TypeRegistry for MemoryTypeRegistry { + fn register_data_type(&mut self, extension_type: ExtensionTypeRef) -> crate::Result<Option<ExtensionTypeRef>> { + Ok(self.types.insert(extension_type.type_signature(), extension_type)) + } + + fn data_type(&self, signature: &TypeSignature) -> crate::Result<ExtensionTypeRef> { + self.types + .get(signature) + .cloned() + .ok_or_else(|| _plan_datafusion_err!("Type with signature {signature:?} not found")) + } +} + diff --git a/datafusion/common/src/logical_type/schema.rs b/datafusion/common/src/logical_type/schema.rs new file mode 100644 index 000000000000..f0c2854b4e64 --- /dev/null +++ b/datafusion/common/src/logical_type/schema.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_schema::Schema; + +use crate::logical_type::field::{LogicalField, LogicalFieldRef}; +use crate::logical_type::fields::LogicalFields; + +#[derive(Debug, Default)] +pub struct LogicalSchemaBuilder { + fields: Vec<LogicalFieldRef>, + metadata: HashMap<String, String>, +} + +impl LogicalSchemaBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + fields: Vec::with_capacity(capacity), + metadata: Default::default(), + } + } + + pub fn push(&mut self, field: impl Into<LogicalFieldRef>) { + self.fields.push(field.into()) + } + + pub fn remove(&mut self, idx: usize) -> LogicalFieldRef { + self.fields.remove(idx) + } + + pub fn field(&self, idx: usize) -> &LogicalFieldRef { + &self.fields[idx] + } + + pub fn field_mut(&mut self, idx: usize) -> &mut LogicalFieldRef { + &mut self.fields[idx] + } + + pub fn metadata(&self) -> &HashMap<String, String> { + &self.metadata + } + + pub fn metadata_mut(&mut self) -> &mut HashMap<String, String> { + &mut self.metadata + } + + pub fn reverse(&mut self) { + self.fields.reverse(); + } + + pub fn finish(self) -> LogicalSchema { + LogicalSchema { + fields: self.fields.into(), + metadata: self.metadata, + } + } +} + +impl From<&LogicalFields> for LogicalSchemaBuilder { + fn from(value: &LogicalFields) -> Self { + Self { + fields: value.to_vec(), + metadata: Default::default(), + } + } +} + +impl From<LogicalFields> for LogicalSchemaBuilder { + fn from(value: LogicalFields) -> Self { + Self { + fields: value.to_vec(), + metadata: Default::default(), + } + } +} + +impl From<&LogicalSchema> for LogicalSchemaBuilder { + fn from(value: &LogicalSchema) -> Self { + Self::from(value.clone()) + } +} + +impl From<LogicalSchema> for LogicalSchemaBuilder { + fn from(value: LogicalSchema) -> Self { + Self { + fields: value.fields.to_vec(), + metadata: value.metadata, + } + } +} + +impl Extend<LogicalFieldRef> for LogicalSchemaBuilder { + fn extend<T: IntoIterator<Item = LogicalFieldRef>>(&mut self, iter: T) { + let iter = iter.into_iter(); + self.fields.reserve(iter.size_hint().0); + for f in iter { + self.push(f) + } + } +} + +impl Extend<LogicalField> for LogicalSchemaBuilder { + fn extend<T: IntoIterator<Item = LogicalField>>(&mut self, iter: T) { + let iter = iter.into_iter(); + self.fields.reserve(iter.size_hint().0); + for f in iter { + self.push(f) + } + } +} + +pub type LogicalSchemaRef = Arc<LogicalSchema>; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LogicalSchema { + pub fields: LogicalFields, + pub metadata: HashMap<String, String>, +} + +impl From<Schema>
for LogicalSchema { + fn from(value: Schema) -> Self { + Self { + fields: value.fields.into(), + metadata: value.metadata, + } + } +} + +impl LogicalSchema { + pub fn new(fields: impl Into<LogicalFields>) -> Self { + Self::new_with_metadata(fields, HashMap::new()) + } + + #[inline] + pub fn new_with_metadata(fields: impl Into<LogicalFields>, metadata: HashMap<String, String>) -> Self { + Self { + fields: fields.into(), + metadata, + } + } + + #[inline] + pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self { + self.metadata = metadata; + self + } + + pub fn metadata(&self) -> &HashMap<String, String> { + &self.metadata + } + + pub fn field(&self, i: usize) -> &LogicalFieldRef { + &self.fields[i] + } +} \ No newline at end of file diff --git a/datafusion/common/src/logical_type/type_signature.rs b/datafusion/common/src/logical_type/type_signature.rs new file mode 100644 index 000000000000..4b46d040e5c4 --- /dev/null +++ b/datafusion/common/src/logical_type/type_signature.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TypeSignature { + // **func_name**(p1, p2) + name: Arc<str>, + // func_name(**p1**, **p2**) + params: Vec<Arc<str>>, +} + +impl TypeSignature { + pub fn new(name: impl Into<Arc<str>>) -> Self { + Self::new_with_params(name, vec![]) + } + + pub fn new_with_params( + name: impl Into<Arc<str>>, + params: Vec<Arc<str>>, + ) -> Self { + Self { + name: name.into(), + params, + } + } +} diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 8d61bad97b9f..e389d520d834 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -17,8 +17,8 @@ use crate::error::{_plan_datafusion_err, _plan_err}; use crate::{Result, ScalarValue}; -use arrow_schema::DataType; use std::collections::HashMap; +use crate::logical_type::LogicalType; /// The parameter value corresponding to the placeholder #[derive(Debug, Clone)] pub enum ParamValues { @@ -31,7 +31,7 @@ impl ParamValues { /// Verify parameter list length and type - pub fn verify(&self, expect: &[DataType]) -> Result<()> { + pub fn verify(&self, expect: &[LogicalType]) -> Result<()> { match self { ParamValues::List(list) => { // Verify if the number of params matches the number of values @@ -46,7 +46,7 @@ impl ParamValues { // Verify if the types of the params matches the types of the values let iter = expect.iter().zip(list.iter()); for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { + if *param_type != value.data_type().into() { return _plan_err!( "Expected parameter of type {:?}, got {:?} at index {}", param_type, diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5b9c4a223de6..67e959984b3d ---
a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -54,10 +54,11 @@ use arrow::{ }, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; -use arrow_schema::{UnionFields, UnionMode}; +use arrow_schema::{FieldRef, UnionFields, UnionMode}; use half::f16; pub use struct_builder::ScalarStructBuilder; +use crate::logical_type::LogicalType; /// A dynamically typed, nullable single value. /// @@ -3270,6 +3271,118 @@ impl TryFrom<&DataType> for ScalarValue { } } + +impl TryFrom for ScalarValue { + type Error = DataFusionError; + + /// Create a Null instance of ScalarValue for this datatype + fn try_from(datatype: LogicalType) -> Result { + (&datatype).try_into() + } +} + +impl TryFrom<&LogicalType> for ScalarValue { + type Error = DataFusionError; + + /// Create a Null instance of ScalarValue for this datatype + fn try_from(data_type: &LogicalType) -> Result { + Ok(match data_type { + LogicalType::Boolean => ScalarValue::Boolean(None), + LogicalType::Float16 => ScalarValue::Float16(None), + LogicalType::Float64 => ScalarValue::Float64(None), + LogicalType::Float32 => ScalarValue::Float32(None), + LogicalType::Int8 => ScalarValue::Int8(None), + LogicalType::Int16 => ScalarValue::Int16(None), + LogicalType::Int32 => ScalarValue::Int32(None), + LogicalType::Int64 => ScalarValue::Int64(None), + LogicalType::UInt8 => ScalarValue::UInt8(None), + LogicalType::UInt16 => ScalarValue::UInt16(None), + LogicalType::UInt32 => ScalarValue::UInt32(None), + LogicalType::UInt64 => ScalarValue::UInt64(None), + LogicalType::Decimal128(precision, scale) => { + ScalarValue::Decimal128(None, *precision, *scale) + } + LogicalType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } + LogicalType::Utf8 => ScalarValue::Utf8(None), + LogicalType::LargeUtf8 => ScalarValue::LargeUtf8(None), + LogicalType::Binary => ScalarValue::Binary(None), + LogicalType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), + LogicalType::LargeBinary => ScalarValue::LargeBinary(None), + LogicalType::Date32 => ScalarValue::Date32(None), + LogicalType::Date64 => ScalarValue::Date64(None), + LogicalType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), + LogicalType::Time32(TimeUnit::Millisecond) => { + ScalarValue::Time32Millisecond(None) + } + LogicalType::Time64(TimeUnit::Microsecond) => { + ScalarValue::Time64Microsecond(None) + } + LogicalType::Time64(TimeUnit::Nanosecond) => ScalarValue::Time64Nanosecond(None), + LogicalType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) + } + LogicalType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) + } + LogicalType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) + } + LogicalType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + LogicalType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(None) + } + LogicalType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(None) + } + LogicalType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(None) + } + LogicalType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), + LogicalType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(None) + } + LogicalType::Duration(TimeUnit::Microsecond) => { + 
ScalarValue::DurationMicrosecond(None) + } + LogicalType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(None) + } + // `ScalaValue::List` contains single element `ListArray`. + LogicalType::List(field_ref) => ScalarValue::List(Arc::new( + GenericListArray::new_null(FieldRef::new(field_ref.as_ref().clone().into()), 1), + )), + // `ScalarValue::LargeList` contains single element `LargeListArray`. + LogicalType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( + GenericListArray::new_null(FieldRef::new(field_ref.as_ref().clone().into()), 1), + )), + // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. + LogicalType::FixedSizeList(field_ref, fixed_length) => { + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( + FieldRef::new(field_ref.as_ref().clone().into()), + *fixed_length, + 1, + ))) + } + LogicalType::Struct(fields) => ScalarValue::Struct( + new_null_array(&DataType::Struct(fields.clone().into()), 1) + .as_struct() + .to_owned() + .into(), + ), + LogicalType::Null => ScalarValue::Null, + _ => { + return _not_impl_err!( + "Can't create a scalar from data_type \"{data_type:?}\"" + ); + } + }) + } +} + macro_rules! format_option { ($F:expr, $EXPR:expr) => {{ match $EXPR { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8e55da8c3ad0..afcfe6615c0a 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -58,6 +58,8 @@ use datafusion_expr::{ use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum}; use async_trait::async_trait; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::LogicalType; /// Contains options that control how data is /// written out from a DataFrame @@ -666,7 +668,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + !matches!(f.data_type(), LogicalType::Binary | LogicalType::Boolean) }) .map(|f| min(col(f.name())).alias(f.name())) .collect::>(), @@ -677,7 +679,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + !matches!(f.data_type(), LogicalType::Binary | LogicalType::Boolean) }) .map(|f| max(col(f.name())).alias(f.name())) .collect::>(), @@ -1285,7 +1287,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::insert_into( self.plan, table_name.to_owned(), - &arrow_schema, + &arrow_schema.into(), write_options.overwrite, )? .build()?; @@ -1695,6 +1697,7 @@ mod tests { use arrow::array::{self, Int32Array}; use datafusion_common::{Constraint, Constraints}; + use datafusion_common::logical_type::LogicalType; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, @@ -2362,7 +2365,7 @@ mod tests { let field = df.schema().field(0); // There are two columns named 'c', one from the input of the aggregate and the other from the output. // Select should return the column from the output of the aggregate, which is a list. - assert!(matches!(field.data_type(), DataType::List(_))); + assert!(matches!(field.data_type(), LogicalType::List(_))); Ok(()) } @@ -3169,7 +3172,7 @@ mod tests { .await? .select_columns(&["c2", "c3"])? .limit(0, Some(1))? 
- .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + .with_column("sum", cast(col("c2") + col("c3"), LogicalType::Int64))?; let df_results = df.clone().collect().await?; df.clone().show().await?; @@ -3271,7 +3274,7 @@ mod tests { .await? .select_columns(&["c2", "c3"])? .limit(0, Some(1))? - .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + .with_column("sum", cast(col("c2") + col("c3"), LogicalType::Int64))?; let cached_df = df.clone().cache().await?; diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 822a66783819..8236f8f06c71 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -44,6 +44,7 @@ use datafusion_expr::{Expr, Volatility}; use datafusion_physical_expr::create_physical_expr; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; +use datafusion_common::logical_type::field::LogicalField; /// Check whether the given expression can be resolved using only the columns `col_names`. /// This means that if this function returns true: @@ -264,7 +265,7 @@ async fn prune_partitions( let df_schema = DFSchema::from_unqualifed_fields( partition_cols .iter() - .map(|(n, d)| Field::new(n, d.clone(), true)) + .map(|(n, d)| LogicalField::new(n, d.clone().into(), true)) .collect(), Default::default(), )?; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 74aca82b3ee6..ffa0013b31cb 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -55,6 +55,7 @@ use async_trait::async_trait; use futures::{future, stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; +use datafusion_common::logical_type::schema::LogicalSchema; /// Configuration for creating a [`ListingTable`] #[derive(Debug, Clone)] @@ -788,7 +789,7 @@ impl TableProvider for ListingTable { let filters = if let Some(expr) = conjunction(filters.to_vec()) { // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. - let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; + let table_df_schema = LogicalSchema::from(self.table_schema.as_ref().clone()).to_dfschema()?; let filters = create_physical_expr(&expr, &table_df_schema, state.execution_props())?; Some(filters) @@ -1877,7 +1878,7 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. 
// Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema.as_ref().clone().into(), false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index aab42285a0b2..3126e8dee5de 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -624,7 +624,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema.as_ref().clone().into(), false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index f5d3c7a6410d..f09e77a9a8be 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -626,7 +626,6 @@ fn create_output_array( #[cfg(test)] mod tests { use arrow_array::Int32Array; - use super::*; use crate::{test::columns, test_util::aggr_test_schema}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index f9cce5f783ff..b73bbd8dad66 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -425,6 +425,9 @@ mod test { use parquet::file::reader::{FileReader, SerializedFileReader}; use rand::prelude::*; + use datafusion_common::logical_type::LogicalType; + + // We should ignore predicate that read non-primitive columns #[test] fn test_filter_candidate_builder_ignore_complex_types() { @@ -471,10 +474,10 @@ mod test { ]); // The parquet file with `file_schema` just has `bigint_col` and `float_col` column, and don't have the `int_col` - let expr = col("bigint_col").eq(cast(col("int_col"), DataType::Int64)); + let expr = col("bigint_col").eq(cast(col("int_col"), LogicalType::Int64)); let expr = logical2physical(&expr, &table_schema); let expected_candidate_expr = - col("bigint_col").eq(cast(lit(ScalarValue::Int32(None)), DataType::Int64)); + col("bigint_col").eq(cast(lit(ScalarValue::Int32(None)), LogicalType::Int64)); let expected_candidate_expr = logical2physical(&expected_candidate_expr, &table_schema); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 9bc79805746f..87d64c4a43ba 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -411,9 +411,8 @@ mod tests { use crate::datasource::physical_plan::parquet::reader::ParquetFileReader; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; - use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::{DataType, Field}; - use datafusion_common::Result; + use datafusion_common::{Result, logical_type::LogicalType::*}; use datafusion_expr::{cast, col, lit, Expr}; use 
datafusion_physical_expr::planner::logical2physical; @@ -819,7 +818,7 @@ mod tests { .with_scale(0) .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); - let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( + let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast( lit(ScalarValue::Decimal128(Some(500), 5, 2)), Decimal128(11, 2), )); @@ -936,7 +935,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -1010,7 +1009,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 2b7867e72046..6068b89cb4be 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -42,7 +42,7 @@ use crate::physical_optimizer::optimizer::PhysicalOptimizer; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use crate::{functions, functions_aggregate}; -use arrow_schema::{DataType, SchemaRef}; +use arrow_schema::{SchemaRef}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use datafusion_common::alias::AliasGenerator; @@ -84,6 +84,7 @@ use std::fmt::Debug; use std::sync::Arc; use url::Url; use uuid::Uuid; +use datafusion_common::logical_type::LogicalType; /// Execution context for registering data sources and executing queries. /// See [`SessionContext`] for a higher level API. @@ -1001,7 +1002,7 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.window_functions().get(name).cloned() } - fn get_variable_type(&self, variable_names: &[String]) -> Option { + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; } @@ -1218,7 +1219,7 @@ impl<'a> SessionSimplifyProvider<'a> { impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { fn is_boolean_type(&self, expr: &Expr) -> datafusion_common::Result { - Ok(expr.get_type(self.df_schema)? == DataType::Boolean) + Ok(expr.get_type(self.df_schema)? 
== LogicalType::Boolean) } fn nullable(&self, expr: &Expr) -> datafusion_common::Result { @@ -1229,7 +1230,7 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { self.state.execution_props() } - fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { + fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { expr.get_type(self.df_schema) } } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index e8f2f34abda0..8624a80bfce3 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -1571,6 +1571,7 @@ mod tests { use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; use datafusion_physical_expr::planner::logical2physical; + use datafusion_common::logical_type::LogicalType; #[derive(Debug, Default)] /// Mock statistic provider for tests @@ -2607,13 +2608,13 @@ mod tests { // test cast(c1 as int64) = 1 // test column on the left - let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); + let expr = cast(col("c1"), LogicalType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); + let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), LogicalType::Int64)); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2625,14 +2626,14 @@ mod tests { // test column on the left let expr = - try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); + try_cast(col("c1"), LogicalType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = - lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); + lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), LogicalType::Int64)); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2645,7 +2646,7 @@ mod tests { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); // test cast(c1 as int64) in int64(1, 2, 3) let expr = Expr::InList(InList::new( - Box::new(cast(col("c1"), DataType::Int64)), + Box::new(cast(col("c1"), LogicalType::Int64)), vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(2))), @@ -2670,7 +2671,7 @@ mod tests { assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( - Box::new(cast(col("c1"), DataType::Int64)), + Box::new(cast(col("c1"), LogicalType::Int64)), vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(2))), @@ -2724,7 +2725,7 @@ mod tests { prune_with_expr( // with cast column to other type - cast(col("s1"), DataType::Decimal128(14, 3)) + cast(col("s1"), LogicalType::Decimal128(14, 3)) .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), &schema, &TestStatistics::new().with( @@ -2739,7 +2740,7 @@ mod tests { prune_with_expr( // with try cast column to other type - try_cast(col("s1"), DataType::Decimal128(14, 3)) + try_cast(col("s1"), 
LogicalType::Decimal128(14, 3)) .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), &schema, &TestStatistics::new().with( @@ -2826,7 +2827,7 @@ mod tests { prune_with_expr( // filter with cast - cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))), + cast(col("s2"), LogicalType::Int64).gt(lit(ScalarValue::Int64(Some(5)))), &schema, &statistics, &[false, true, true, true], @@ -3054,7 +3055,7 @@ mod tests { prune_with_expr( // cast(i as utf8) <= 0 - cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + cast(col("i"), LogicalType::Utf8).lt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3062,7 +3063,7 @@ mod tests { prune_with_expr( // try_cast(i as utf8) <= 0 - try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + try_cast(col("i"), LogicalType::Utf8).lt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3070,7 +3071,7 @@ mod tests { prune_with_expr( // cast(-i as utf8) >= 0 - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + cast(Expr::Negative(Box::new(col("i"))), LogicalType::Utf8).gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3078,7 +3079,7 @@ mod tests { prune_with_expr( // try_cast(-i as utf8) >= 0 - try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + try_cast(Expr::Negative(Box::new(col("i"))), LogicalType::Utf8).gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3119,14 +3120,14 @@ mod tests { let expected_ret = &[true, false, false, true, false]; prune_with_expr( - cast(col("i"), DataType::Int64).eq(lit(0i64)), + cast(col("i"), LogicalType::Int64).eq(lit(0i64)), &schema, &statistics, expected_ret, ); prune_with_expr( - try_cast(col("i"), DataType::Int64).eq(lit(0i64)), + try_cast(col("i"), LogicalType::Int64).eq(lit(0i64)), &schema, &statistics, expected_ret, @@ -3149,7 +3150,7 @@ mod tests { let expected_ret = &[true, true, true, true, true]; prune_with_expr( - cast(col("i"), DataType::Utf8).eq(lit("0")), + cast(col("i"), LogicalType::Utf8).eq(lit("0")), &schema, &statistics, expected_ret, @@ -3304,7 +3305,7 @@ mod tests { prune_with_expr( // i > int64(0) - col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)), + col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), LogicalType::Int32)), &schema, &statistics, expected_ret, @@ -3312,7 +3313,7 @@ mod tests { prune_with_expr( // cast(i as int64) > int64(0) - cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + cast(col("i"), LogicalType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, expected_ret, @@ -3320,7 +3321,7 @@ mod tests { prune_with_expr( // try_cast(i as int64) > int64(0) - try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + try_cast(col("i"), LogicalType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, expected_ret, @@ -3328,7 +3329,7 @@ mod tests { prune_with_expr( // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) + Expr::Negative(Box::new(cast(col("i"), LogicalType::Int64))) .lt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, @@ -3357,7 +3358,7 @@ mod tests { assert_eq!(result_right.to_string(), right_input.to_string()); // cast op lit - let left_input = cast(col("a"), DataType::Decimal128(20, 3)); + let left_input = cast(col("a"), LogicalType::Decimal128(20, 3)); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3)); let right_input = logical2physical(&right_input, &schema); @@ -3372,7 
+3373,7 @@ mod tests { assert_eq!(result_right.to_string(), right_input.to_string()); // try_cast op lit - let left_input = try_cast(col("a"), DataType::Int64); + let left_input = try_cast(col("a"), LogicalType::Int64); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); @@ -3391,7 +3392,7 @@ mod tests { // this cast is not supported let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let df_schema = DFSchema::try_from(schema.clone()).unwrap(); - let left_input = cast(col("a"), DataType::Int64); + let left_input = cast(col("a"), LogicalType::Int64); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 5b8501baaad8..e971c909a5c2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2232,6 +2232,8 @@ mod tests { use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; + use datafusion_common::logical_type::field::LogicalField; + use datafusion_common::logical_type::LogicalType; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; @@ -2474,12 +2476,11 @@ mod tests { let expected_error: &str = "Error during planning: \ Extension planner for NoOp created an ExecutionPlan with mismatched schema. \ LogicalPlan schema: \ - DFSchema { inner: Schema { fields: \ - [Field { name: \"a\", \ + DFSchema { inner: LogicalSchema { fields: \ + [LogicalField { name: \"a\", \ data_type: Int32, \ nullable: false, \ - dict_id: 0, \ - dict_is_ordered: false, metadata: {} }], \ + metadata: {} }], \ metadata: {} }, field_qualifiers: [None], \ functional_dependencies: FunctionalDependencies { deps: [] } }, \ ExecutionPlan schema: Schema { fields: \ @@ -2753,7 +2754,7 @@ mod tests { Self { schema: DFSchemaRef::new( DFSchema::from_unqualifed_fields( - vec![Field::new("a", DataType::Int32, false)].into(), + vec![LogicalField::new("a", LogicalType::Int32, false)].into(), HashMap::new(), ) .unwrap(), diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index e8550a79cb0e..1212b572df3e 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -206,7 +206,7 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { .schema() .fields() .iter() - .map(|f| f.name().clone()) + .map(|f| f.name().to_string()) .collect(); assert_eq!(actual, expected); } diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index 38207b42cb7b..7cf5a7f3ec2b 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -20,7 +20,7 @@ use crate::error::Result; use crate::scalar::ScalarValue; use crate::variable::VarProvider; -use arrow::datatypes::DataType; +use datafusion_common::logical_type::LogicalType; /// System variable #[derive(Default, Debug)] @@ -40,8 +40,8 @@ impl VarProvider for SystemVar { Ok(ScalarValue::from(s)) } - fn get_type(&self, _: &[String]) -> Option { - Some(DataType::Utf8) + fn get_type(&self, _: &[String]) -> Option { + Some(LogicalType::Utf8) } } @@ -67,11 
+67,11 @@ impl VarProvider for UserDefinedVar { } } - fn get_type(&self, var_names: &[String]) -> Option { + fn get_type(&self, var_names: &[String]) -> Option { if var_names[0] != "@integer" { - Some(DataType::Utf8) + Some(LogicalType::Utf8) } else { - Some(DataType::Int32) + Some(LogicalType::Int32) } } } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 9f06ad9308ab..e8e33badc0af 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -42,6 +42,7 @@ use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; +use datafusion_common::logical_type::schema::LogicalSchema; /// a ParquetFile that has been created for testing. pub struct TestParquetFile { @@ -153,7 +154,7 @@ impl TestParquetFile { extensions: None, }); - let df_schema = self.schema.clone().to_dfschema_ref()?; + let df_schema = LogicalSchema::from(self.schema.as_ref().clone()).to_dfschema_ref()?; // run coercion on the filters to coerce types etc. let props = ExecutionProps::new(); diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 1c55c48fea40..b1d615405cac 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,6 +31,7 @@ use datafusion::prelude::*; use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_common::logical_type::LogicalType; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; @@ -376,7 +377,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { // the arg2 parameter is a complex expr, but it can be evaluated to the literal value let alias_expr = Expr::Alias(Alias::new( - cast(lit(0.5), DataType::Float32), + cast(lit(0.5), LogicalType::Float32), None::<&str>, "arg_2".to_string(), )); @@ -949,7 +950,7 @@ async fn test_fn_substr() -> Result<()> { #[tokio::test] async fn test_cast() -> Result<()> { - let expr = cast(col("b"), DataType::Float64); + let expr = cast(col("b"), LogicalType::Float64); let expected = [ "+--------+", "| test.b |", @@ -1052,7 +1053,7 @@ async fn test_fn_decode() -> Result<()> { let expr = decode(encode(col("a"), lit("hex")), lit("hex")) // need to cast to utf8 otherwise the default display of binary array is hex // so it looks like nothing is done - .cast_to(&DataType::Utf8, &df_schema)?; + .cast_to(&LogicalType::Utf8, &df_schema)?; let expected = [ "+------------------------------------------------+", diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c3bc2fcca2b5..d7228dcf3689 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -47,6 +47,7 @@ use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::{parquet_test_data, populate_csv_partitions}; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOptions}; +use datafusion_common::logical_type::LogicalType; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; @@ -256,7 +257,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> 
{ scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? + .filter(out_ref_col(LogicalType::UInt32, "t1.a").eq(col("t2.a")))? .aggregate(vec![], vec![count(wildcard())])? .select(vec![col(count(wildcard()).to_string())])? .into_unoptimized_plan(), @@ -746,8 +747,8 @@ async fn join_with_alias_filter() -> Result<()> { // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2 let filter = Expr::eq( - col("t1.a") + lit(3i64).cast_to(&DataType::UInt32, &t1_schema)?, - col("t2.a") + lit(1i32).cast_to(&DataType::UInt32, &t2_schema)?, + col("t1.a") + lit(3i64).cast_to(&LogicalType::UInt32, &t1_schema)?, + col("t2.a") + lit(1i32).cast_to(&LogicalType::UInt32, &t2_schema)?, ) .alias("t1.b + 1 = t2.a + 2"); @@ -1622,7 +1623,7 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column("t", cast(Expr::Literal(ScalarValue::Null), LogicalType::Int32)) .unwrap(); df.clone().show().await.unwrap(); @@ -1925,8 +1926,8 @@ impl VarProvider for HardcodedIntProvider { Ok(ScalarValue::Int64(Some(1234))) } - fn get_type(&self, _: &[String]) -> Option { - Some(DataType::Int64) + fn get_type(&self, _: &[String]) -> Option { + Some(LogicalType::Int64) } } diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 991579b5a350..e082aaed322b 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::LogicalSchema; use datafusion_expr::Expr; use datafusion_sql::unparser::Unparser; @@ -27,10 +29,10 @@ use datafusion_sql::unparser::Unparser; /// b: Int32 /// s: Float32 fn schema() -> DFSchemaRef { - Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Float32, false), + LogicalSchema::new(vec![ + LogicalField::new("a", LogicalType::Int32, true), + LogicalField::new("b", LogicalType::Int32, false), + LogicalField::new("c", LogicalType::Float32, false), ]) .to_dfschema_ref() .unwrap() diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 9ce47153ba4a..38e1e8cd1551 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -37,6 +37,9 @@ use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::LogicalSchema; /// In order to simplify expressions, DataFusion must have information /// about the expressions. 
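[Aside for orientation, not part of the patch: a minimal sketch of the type flow this diff introduces. It mirrors the schema-building pattern used in the test changes above; the column name and assertion are hypothetical.]

    use datafusion_common::logical_type::field::LogicalField;
    use datafusion_common::logical_type::schema::LogicalSchema;
    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::{Result, ToDFSchema};
    use datafusion_expr::{col, ExprSchemable};

    fn type_flow_sketch() -> Result<()> {
        // Build a DFSchema straight from logical fields, with no detour
        // through arrow_schema::DataType.
        let df_schema = LogicalSchema::new(vec![LogicalField::new(
            "flag",
            LogicalType::Boolean,
            false,
        )])
        .to_dfschema_ref()?;

        // ExprSchemable::get_type now answers in LogicalType terms.
        assert_eq!(col("flag").get_type(df_schema.as_ref())?, LogicalType::Boolean);
        Ok(())
    }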
@@ -56,7 +59,7 @@ impl SimplifyInfo for MyInfo {
     fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
         Ok(matches!(
             expr.get_type(self.schema.as_ref())?,
-            DataType::Boolean
+            LogicalType::Boolean
         ))
     }
 
@@ -68,7 +71,7 @@ impl SimplifyInfo for MyInfo {
         &self.execution_props
     }
 
-    fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
+    fn get_data_type(&self, expr: &Expr) -> Result<LogicalType> {
         expr.get_type(self.schema.as_ref())
     }
 }
@@ -88,10 +91,10 @@ impl From<DFSchemaRef> for MyInfo {
 ///   b: Int32
 ///   s: Utf8
 fn schema() -> DFSchemaRef {
-    Schema::new(vec![
-        Field::new("a", DataType::Int32, true),
-        Field::new("b", DataType::Int32, false),
-        Field::new("s", DataType::Utf8, false),
+    LogicalSchema::new(vec![
+        LogicalField::new("a", LogicalType::Int32, true),
+        LogicalField::new("b", LogicalType::Int32, false),
+        LogicalField::new("s", LogicalType::Utf8, false),
     ])
     .to_dfschema_ref()
     .unwrap()
@@ -190,7 +193,7 @@ fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
 }
 
 fn cast_to_int64_expr(expr: Expr) -> Expr {
-    Expr::Cast(Cast::new(expr.into(), DataType::Int64))
+    Expr::Cast(Cast::new(expr.into(), LogicalType::Int64))
 }
 
 fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
@@ -281,7 +284,7 @@ fn select_date_plus_interval() -> Result<()> {
     let schema = table_scan.schema();
 
     let date_plus_interval_expr = to_timestamp_expr(ts_string)
-        .cast_to(&DataType::Date32, schema)?
+        .cast_to(&LogicalType::Date32, schema)?
         + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime {
             days: 123,
             milliseconds: 0,
@@ -483,15 +486,15 @@ fn multiple_now() -> Result<()> {
 // ------------------------------
 
 fn expr_test_schema() -> DFSchemaRef {
-    Schema::new(vec![
-        Field::new("c1", DataType::Utf8, true),
-        Field::new("c2", DataType::Boolean, true),
-        Field::new("c3", DataType::Int64, true),
-        Field::new("c4", DataType::UInt32, true),
-        Field::new("c1_non_null", DataType::Utf8, false),
-        Field::new("c2_non_null", DataType::Boolean, false),
-        Field::new("c3_non_null", DataType::Int64, false),
-        Field::new("c4_non_null", DataType::UInt32, false),
+    LogicalSchema::new(vec![
+        LogicalField::new("c1", LogicalType::Utf8, true),
+        LogicalField::new("c2", LogicalType::Boolean, true),
+        LogicalField::new("c3", LogicalType::Int64, true),
+        LogicalField::new("c4", LogicalType::UInt32, true),
+        LogicalField::new("c1_non_null", LogicalType::Utf8, false),
+        LogicalField::new("c2_non_null", LogicalType::Boolean, false),
+        LogicalField::new("c3_non_null", LogicalType::Int64, false),
+        LogicalField::new("c4_non_null", LogicalType::UInt32, false),
     ])
     .to_dfschema_ref()
     .unwrap()
@@ -688,8 +691,8 @@ fn test_simplify_concat() {
 #[test]
 fn test_simplify_cycles() {
     // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX
-    let expr = cast(now(), DataType::Int64)
-        .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
+    let expr = cast(now(), LogicalType::Int64)
+        .lt(cast(to_timestamp(vec![lit(0)]), LogicalType::Int64) + lit(i64::MAX));
     let expected = lit(true);
     test_simplify_with_cycle_count(expr, expected, 3);
 }
diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs
index 39f745cd3309..40c91c4bd4ff 100644
--- a/datafusion/core/tests/optimizer_integration.rs
+++ b/datafusion/core/tests/optimizer_integration.rs
@@ -25,6 +25,8 @@ use std::sync::Arc;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
 use arrow_schema::{Fields, SchemaBuilder};
 use datafusion_common::config::ConfigOptions;
+use datafusion_common::logical_type::schema::LogicalSchemaRef;
+use datafusion_common::logical_type::LogicalType;
 use datafusion_common::tree_node::{TransformedResult, TreeNode};
 use datafusion_common::{plan_err, DFSchema, Result, ScalarValue};
 use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
@@ -203,7 +205,7 @@ impl ContextProvider for MyContextProvider {
         None
     }
 
-    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
+    fn get_variable_type(&self, _variable_names: &[String]) -> Option<LogicalType> {
         None
     }
 
@@ -259,7 +261,7 @@ fn test_nested_schema_nullability() {
 
     let dfschema = DFSchema::from_field_specific_qualified_schema(
         vec![Some("table_name".into()), None],
-        &Arc::new(schema),
+        &LogicalSchemaRef::new(schema.into()),
     )
     .unwrap();
diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs
index 15efd4bcd9dd..cfe74b12e25a 100644
--- a/datafusion/core/tests/parquet/page_pruning.rs
+++ b/datafusion/core/tests/parquet/page_pruning.rs
@@ -35,6 +35,7 @@ use datafusion_physical_expr::create_physical_expr;
 use futures::StreamExt;
 use object_store::path::Path;
 use object_store::ObjectMeta;
+use datafusion_common::logical_type::schema::LogicalSchema;
 
 async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec {
     let object_store_url = ObjectStoreUrl::local_filesystem();
@@ -66,7 +67,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec {
         extensions: None,
     };
 
-    let df_schema = schema.clone().to_dfschema().unwrap();
+    let df_schema = LogicalSchema::from(schema.as_ref().clone()).to_dfschema().unwrap();
 
     let execution_props = ExecutionProps::new();
     let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap();
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 5e3c44c039ab..2603290442b6 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -39,6 +39,8 @@ use datafusion_expr::{
 use datafusion_functions_array::range::range_udf;
 use parking_lot::Mutex;
 use sqlparser::ast::Ident;
+use datafusion_common::logical_type::extension::ExtensionType;
+use datafusion_common::logical_type::LogicalType;
 
 /// test that casting happens on udfs.
 /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -518,14 +520,14 @@ impl ScalarUDFImpl for CastToI64UDF {
         // SimplifyInfo so we have to replicate some of the casting logic here.
 
         let source_type = info.get_data_type(&arg)?;
-        let new_expr = if source_type == DataType::Int64 {
+        let new_expr = if source_type == LogicalType::Int64 {
             // the argument's data type is already the correct type
             arg
         } else {
             // need to use an actual cast to get the correct type
             Expr::Cast(datafusion_expr::Cast {
                 expr: Box::new(arg),
-                data_type: DataType::Int64,
+                data_type: LogicalType::Int64,
             })
         };
         // return the newly written argument to DataFusion
@@ -645,7 +647,7 @@ impl ScalarUDFImpl for TakeUDF {
             );
         };
 
-        arg_exprs.get(take_idx).unwrap().get_type(schema)
+        arg_exprs.get(take_idx).unwrap().get_type(schema).map(|t| t.physical_type())
     }
 
     // The actual implementation
@@ -687,8 +689,8 @@ async fn verify_udf_return_type() -> Result<()> {
     // The output schema should be
     // * type of column smallint_col (int32)
     // * type of column double_col (float64)
-    assert_eq!(schema.field(0).data_type(), &DataType::Int32);
-    assert_eq!(schema.field(1).data_type(), &DataType::Float64);
+    assert_eq!(schema.field(0).data_type(), &LogicalType::Int32);
+    assert_eq!(schema.field(1).data_type(), &LogicalType::Float64);
 
     let expected = [
         "+-------+-------+",
@@ -835,13 +837,16 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
                 .expect("Expression has to be defined!"),
             return_type: definition
                 .return_type
-                .expect("Return type has to be defined!"),
+                .expect("Return type has to be defined!")
+                .physical_type(),
+            // TODO(@notfilippo): avoid conversion to physical type
             signature: Signature::exact(
                 definition
                     .args
                     .unwrap_or_default()
                     .into_iter()
-                    .map(|a| a.data_type)
+                    // TODO(@notfilippo): avoid conversion to physical type
+                    .map(|a| a.data_type.physical_type())
                     .collect(),
                 definition
                     .params
@@ -990,10 +995,10 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<()> {
             value: "name".into(),
             quote_style: None,
         }),
-        data_type: DataType::Utf8,
+        data_type: LogicalType::Utf8,
         default_expr: None,
     }]),
-    return_type: Some(DataType::Int32),
+    return_type: Some(LogicalType::Int32),
     params: CreateFunctionBody {
         language: Some(Ident {
             value: "plrust".into(),
diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs
index 7a2bf4b6c44a..176a2002501a 100644
--- a/datafusion/expr/src/conditional_expressions.rs
+++ b/datafusion/expr/src/conditional_expressions.rs
@@ -18,9 +18,9 @@
 //! Conditional expressions
 use crate::expr::Case;
 use crate::{expr_schema::ExprSchemable, Expr};
-use arrow::datatypes::DataType;
 use datafusion_common::{plan_err, DFSchema, Result};
 use std::collections::HashSet;
+use datafusion_common::logical_type::LogicalType;
 
 /// Helper struct for building [Expr::Case]
 pub struct CaseBuilder {
@@ -70,18 +70,18 @@ impl CaseBuilder {
             then_expr.push(e.as_ref().to_owned());
         }
 
-        let then_types: Vec<DataType> = then_expr
+        let then_types: Vec<LogicalType> = then_expr
             .iter()
             .map(|e| match e {
                 Expr::Literal(_) => e.get_type(&DFSchema::empty()),
-                _ => Ok(DataType::Null),
+                _ => Ok(LogicalType::Null),
             })
             .collect::<Result<Vec<_>>>()?;
 
-        if then_types.contains(&DataType::Null) {
+        if then_types.contains(&LogicalType::Null) {
             // cannot verify types until execution type
         } else {
-            let unique_types: HashSet<&DataType> = then_types.iter().collect();
+            let unique_types: HashSet<&LogicalType> = then_types.iter().collect();
             if unique_types.len() != 1 {
                 return plan_err!(
                     "CASE expression 'then' values had multiple data types: {unique_types:?}"
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 846b627b2242..0c084c8cadb7 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -33,7 +33,7 @@ use crate::{
 };
 use crate::{window_frame, Volatility};
 
-use arrow::datatypes::{DataType, FieldRef};
+use arrow::datatypes::DataType;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
@@ -41,6 +41,8 @@ use datafusion_common::{
     internal_err, plan_err, Column, DFSchema, Result, ScalarValue, TableReference,
 };
 use sqlparser::ast::NullTreatment;
+use datafusion_common::logical_type::field::LogicalFieldRef;
+use datafusion_common::logical_type::LogicalType;
 
 /// Represents logical expressions such as `A + 1`, or `CAST(c1 AS int)`.
 ///
@@ -153,7 +155,7 @@ use sqlparser::ast::NullTreatment;
 ///   Field::new("c2", DataType::Float64, false),
 /// ]);
 /// // DFSchema is an Arrow schema with optional relation name
-/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema)
+/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema.into())
///     .unwrap();
 ///
 /// // Form Vec<Expr> with an expression for each column in the schema
@@ -223,7 +225,7 @@ pub enum Expr {
     /// A named reference to a qualified field in a schema.
     Column(Column),
     /// A named reference to a variable in a registry.
-    ScalarVariable(DataType, Vec<String>),
+    ScalarVariable(LogicalType, Vec<String>),
     /// A constant value.
     Literal(ScalarValue),
     /// A binary expression such as "age > 21"
@@ -317,7 +319,7 @@ pub enum Expr {
     Placeholder(Placeholder),
     /// A place holder which holds a reference to a qualified field
     /// in the outer query, used for correlated sub queries.
-    OuterReferenceColumn(DataType, Column),
+    OuterReferenceColumn(LogicalType, Column),
     /// Unnest expression
     Unnest(Unnest),
 }
@@ -339,8 +341,8 @@ impl From<Column> for Expr {
 /// useful for creating [`Expr`] from a [`DFSchema`].
 ///
 /// See example on [`Expr`]
-impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr {
-    fn from(value: (Option<&'a TableReference>, &'a FieldRef)) -> Self {
+impl<'a> From<(Option<&'a TableReference>, &'a LogicalFieldRef)> for Expr {
+    fn from(value: (Option<&'a TableReference>, &'a LogicalFieldRef)) -> Self {
         Expr::from(Column::from(value))
     }
 }
@@ -558,13 +560,13 @@ pub enum GetFieldAccess {
 pub struct Cast {
     /// The expression being cast
     pub expr: Box<Expr>,
-    /// The `DataType` the expression will yield
-    pub data_type: DataType,
+    /// The `LogicalType` the expression will yield
+    pub data_type: LogicalType,
 }
 
 impl Cast {
     /// Create a new Cast expression
-    pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
+    pub fn new(expr: Box<Expr>, data_type: LogicalType) -> Self {
         Self { expr, data_type }
     }
 }
@@ -574,13 +576,13 @@ impl Cast {
 pub struct TryCast {
     /// The expression being cast
     pub expr: Box<Expr>,
-    /// The `DataType` the expression will yield
-    pub data_type: DataType,
+    /// The `LogicalType` the expression will yield
+    pub data_type: LogicalType,
 }
 
 impl TryCast {
     /// Create a new TryCast expression
-    pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
+    pub fn new(expr: Box<Expr>, data_type: LogicalType) -> Self {
         Self { expr, data_type }
     }
 }
@@ -926,12 +928,12 @@ pub struct Placeholder {
     /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`)
     pub id: String,
     /// The type the parameter will be filled in with
-    pub data_type: Option<DataType>,
+    pub data_type: Option<LogicalType>,
 }
 
 impl Placeholder {
     /// Create a new Placeholder expression
-    pub fn new(id: String, data_type: Option<DataType>) -> Self {
+    pub fn new(id: String, data_type: Option<LogicalType>) -> Self {
         Self { id, data_type }
     }
 }
@@ -2060,7 +2062,7 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) -> Result<()> {
         Expr::InSubquery(InSubquery { negated: true, .. }) => w.write_str("NOT IN")?,
         Expr::InSubquery(InSubquery { negated: false, .. }) => w.write_str("IN")?,
         Expr::ScalarSubquery(subquery) => {
-            w.write_str(subquery.subquery.schema().field(0).name().as_str())?;
+            w.write_str(subquery.subquery.schema().field(0).name())?;
         }
         Expr::Unnest(Unnest { expr }) => {
             w.write_str("unnest(")?;
@@ -2222,7 +2224,7 @@ mod test {
     fn format_cast() -> Result<()> {
         let expr = Expr::Cast(Cast {
             expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))),
-            data_type: DataType::Utf8,
+            data_type: LogicalType::Utf8,
         });
         let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
         assert_eq!(expected_canonical, expr.canonical_name());
@@ -2255,7 +2257,7 @@ mod test {
     fn test_collect_expr() -> Result<()> {
         // single column
         {
-            let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
+            let expr = &Expr::Cast(Cast::new(Box::new(col("a")), LogicalType::Float64));
             let columns = expr.column_refs();
             assert_eq!(1, columns.len());
             assert!(columns.contains(&Column::from_name("a")));
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 8b0213fd52fd..2db73e5c413b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -40,6 +40,7 @@ use std::any::Any;
 use std::fmt::Debug;
 use std::ops::Not;
 use std::sync::Arc;
+use datafusion_common::logical_type::LogicalType;
 
 /// Create a column expression based on a qualified or unqualified column name. Will
 /// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
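[Another illustrative aside, not from the patch; the column and literal are hypothetical. With the builder signatures changed below, casts are written against LogicalType and only lowered to a physical Arrow cast kernel at planning time.]

    use datafusion_common::logical_type::LogicalType;
    use datafusion_expr::{cast, col, lit, try_cast};

    fn cast_sketch() {
        // Both expressions carry a LogicalType, not an arrow_schema::DataType.
        let _strict = cast(col("price"), LogicalType::Decimal128(14, 3));
        let _lenient = try_cast(lit("42"), LogicalType::Int64);
    }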
@@ -62,7 +63,7 @@ pub fn col(ident: impl Into<Column>) -> Expr {
 
 /// Create an out reference column which hold a reference that has been resolved to a field
 /// outside of the current plan.
-pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
+pub fn out_ref_col(dt: LogicalType, ident: impl Into<Column>) -> Expr {
     Expr::OuterReferenceColumn(dt, ident.into())
 }
 
@@ -308,12 +309,12 @@ pub fn rollup(exprs: Vec<Expr>) -> Expr {
 }
 
 /// Create a cast expression
-pub fn cast(expr: Expr, data_type: DataType) -> Expr {
+pub fn cast(expr: Expr, data_type: LogicalType) -> Expr {
     Expr::Cast(Cast::new(Box::new(expr), data_type))
 }
 
 /// Create a try cast expression
-pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
+pub fn try_cast(expr: Expr, data_type: LogicalType) -> Expr {
     Expr::TryCast(TryCast::new(Box::new(expr), data_type))
 }
diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs
index 1441374bdba3..ea3d6f353aa9 100644
--- a/datafusion/expr/src/expr_rewriter/mod.rs
+++ b/datafusion/expr/src/expr_rewriter/mod.rs
@@ -286,7 +286,9 @@ mod test {
     use super::*;
     use crate::expr::Sort;
     use crate::{col, lit, Cast};
-    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::logical_type::field::LogicalField;
+    use datafusion_common::logical_type::LogicalType;
+    use datafusion_common::logical_type::schema::LogicalSchema;
     use datafusion_common::ScalarValue;
 
     #[derive(Default)]
@@ -407,10 +409,10 @@ mod test {
     ) -> DFSchema {
         let fields = fields
             .iter()
-            .map(|f| Arc::new(Field::new(f.to_string(), DataType::Int8, false)))
+            .map(|f| Arc::new(LogicalField::new(f.to_string(), LogicalType::Int8, false)))
             .collect::<Vec<_>>();
-        let schema = Arc::new(Schema::new(fields));
-        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
+        let schema = Arc::new(LogicalSchema::new(fields));
+        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema.into()).unwrap()
     }
 
     #[test]
@@ -440,7 +442,7 @@ mod test {
         // cast data types
         test_rewrite(
             col("a"),
-            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
+            Expr::Cast(Cast::new(Box::new(col("a")), LogicalType::Int32)),
         );
 
         // change literal type from i32 to i64
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs
index 4b56ca3d1c2e..8084995bdf63 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -154,7 +154,7 @@ mod test {
     use std::sync::Arc;
 
     use arrow::datatypes::{DataType, Field, Schema};
-
+    use datafusion_common::logical_type::LogicalType;
     use crate::{
         cast, col, lit, logical_plan::builder::LogicalTableSource, min,
         test::function_stub::avg, try_cast, LogicalPlanBuilder,
@@ -270,13 +270,13 @@ mod test {
     let cases = vec![
         TestCase {
             desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
-            input: sort(cast(col("c2"), DataType::Int64)),
-            expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
+            input: sort(cast(col("c2"), LogicalType::Int64)),
+            expected: sort(cast(col("c2").alias("c2"), LogicalType::Int64)),
         },
         TestCase {
             desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
-            input: sort(try_cast(col("c2"), DataType::Int64)),
-            expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
+            input: sort(try_cast(col("c2"), LogicalType::Int64)),
+            expected: sort(try_cast(col("c2").alias("c2"), LogicalType::Int64)),
         },
     ];
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index d5a04ad4ae1f..8106be838676 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -26,18 +26,20 @@ use crate::type_coercion::functions::{
 };
 use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Field};
 use datafusion_common::{
     internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, ExprSchema,
     Result, TableReference,
 };
 use std::collections::HashMap;
 use std::sync::Arc;
+use datafusion_common::logical_type::extension::ExtensionType;
+use datafusion_common::logical_type::field::LogicalField;
+use datafusion_common::logical_type::LogicalType;
 
 /// trait to allow expr to be typable with respect to a schema
 pub trait ExprSchemable {
     /// given a schema, return the type of the expr
-    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
+    fn get_type(&self, schema: &dyn ExprSchema) -> Result<LogicalType>;
 
     /// given a schema, return the nullability of the expr
     fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
@@ -49,14 +51,14 @@ pub trait ExprSchemable {
     fn to_field(
         &self,
         input_schema: &dyn ExprSchema,
-    ) -> Result<(Option<TableReference>, Arc<Field>)>;
+    ) -> Result<(Option<TableReference>, Arc<LogicalField>)>;
 
     /// cast to a type with respect to a schema
-    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
+    fn cast_to(self, cast_to_type: &LogicalType, schema: &dyn ExprSchema) -> Result<Expr>;
 
     /// given a schema, return the type and nullability of the expr
     fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
-        -> Result<(DataType, bool)>;
+        -> Result<(LogicalType, bool)>;
 }
 
 impl ExprSchemable for Expr {
@@ -98,7 +100,7 @@ impl ExprSchemable for Expr {
     /// expression refers to a column that does not exist in the
     /// schema, or when the expression is incorrectly typed
     /// (e.g. `[utf8] + [bool]`).
-    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
+    fn get_type(&self, schema: &dyn ExprSchema) -> Result<LogicalType> {
         match self {
             Expr::Alias(Alias { expr, name, .. }) => match &**expr {
                 Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
@@ -111,7 +113,7 @@ impl ExprSchemable for Expr {
             Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
             Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
             Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
-            Expr::Literal(l) => Ok(l.data_type()),
+            Expr::Literal(l) => Ok(l.data_type().into()),
             Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
             Expr::Cast(Cast { data_type, .. })
             | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
@@ -119,13 +121,13 @@ impl ExprSchemable for Expr {
                 let arg_data_type = expr.get_type(schema)?;
                 // Unnest's output type is the inner type of the list
                 match arg_data_type {
-                    DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) => {
+                    LogicalType::List(field) | LogicalType::LargeList(field) | LogicalType::FixedSizeList(field, _) => {
                         Ok(field.data_type().clone())
                     }
-                    DataType::Struct(_) => {
+                    LogicalType::Struct(_) => {
                         Ok(arg_data_type)
                     }
-                    DataType::Null => {
+                    LogicalType::Null => {
                         not_impl_err!("unnest() does not support null yet")
                     }
                     _ => {
@@ -138,28 +140,40 @@ impl ExprSchemable for Expr {
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
-                // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
-                data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
-                    plan_datafusion_err!(
-                        "{} {}",
-                        err,
-                        utils::generate_signature_error_msg(
-                            func.name(),
-                            func.signature().clone(),
-                            &arg_data_types,
-                        )
-                    )
-                })?;
-                // perform additional function arguments validation (due to limited
-                // expressiveness of `TypeSignature`), then infer return type
-                Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
+                // TODO(@notfilippo): not convert to DataType
+                let arg_data_types = arg_data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
+
+                // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
+                data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
+                    plan_datafusion_err!(
+                        "{} {}",
+                        err,
+                        utils::generate_signature_error_msg(
+                            func.name(),
+                            func.signature().clone(),
+                            &arg_data_types,
+                        )
+                    )
+                })?;
+
+                // perform additional function arguments validation (due to limited
+                // expressiveness of `TypeSignature`), then infer return type
+                Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?.into())
             }
             Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
                 let data_types = args
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
+                // TODO(@notfilippo): not convert to DataType
+                let data_types = data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
                 let nullability = args
                     .iter()
                     .map(|e| e.nullable(schema))
@@ -177,10 +191,10 @@ impl ExprSchemable for Expr {
                             )
                         )
                     })?;
-                    Ok(fun.return_type(&new_types, &nullability)?)
+                    Ok(fun.return_type(&new_types, &nullability)?.into())
                 }
                 _ => {
-                    fun.return_type(&data_types, &nullability)
+                    Ok(fun.return_type(&data_types, &nullability)?.into())
                 }
             }
         }
@@ -193,9 +207,14 @@ impl ExprSchemable for Expr {
                 .iter()
                 .map(|e| e.nullable(schema))
                 .collect::<Result<Vec<_>>>()?;
+            // TODO(@notfilippo): not convert to DataType
+            let data_types = data_types
+                .into_iter()
+                .map(|e| e.physical_type())
+                .collect::<Vec<_>>();
             match func_def {
                 AggregateFunctionDefinition::BuiltIn(fun) => {
-                    fun.return_type(&data_types, &nullability)
+                    Ok(fun.return_type(&data_types, &nullability)?.into())
                 }
                 AggregateFunctionDefinition::UDF(fun) => {
                     let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| {
@@ -209,7 +228,7 @@ impl ExprSchemable for Expr {
                             )
                         )
                     })?;
-                    Ok(fun.return_type(&new_types)?)
+                    Ok(fun.return_type(&new_types)?.into())
                 }
             }
         }
@@ -225,7 +244,7 @@ impl ExprSchemable for Expr {
             | Expr::IsUnknown(_)
             | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
-            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
+            | Expr::IsNotUnknown(_) => Ok(LogicalType::Boolean),
             Expr::ScalarSubquery(subquery) => {
                 Ok(subquery.subquery.schema().field(0).data_type().clone())
             }
@@ -233,8 +252,9 @@ impl ExprSchemable for Expr {
                 ref left,
                 ref right,
                 ref op,
-            }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?),
-            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
+                // TODO(@notfilippo): do not convert to physical type
+            }) => Ok(get_result_type(&left.get_type(schema)?.physical_type(), op, &right.get_type(schema)?.physical_type())?.into()),
+            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(LogicalType::Boolean),
             Expr::Placeholder(Placeholder { data_type, .. }) => {
                 data_type.clone().ok_or_else(|| {
                     plan_datafusion_err!("Placeholder type could not be resolved. Make sure that the placeholder is bound to a concrete type, e.g. by providing parameter values.")
@@ -244,12 +264,12 @@ impl ExprSchemable for Expr {
                 // Wildcard do not really have a type and do not appear in projections
                 match qualifier {
                     Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"),
-                    None => Ok(DataType::Null)
+                    None => Ok(LogicalType::Null)
                 }
             }
             Expr::GroupingSet(_) => {
                 // grouping sets do not really have a type and do not appear in projections
-                Ok(DataType::Null)
+                Ok(LogicalType::Null)
             }
         }
     }
@@ -392,7 +412,7 @@ impl ExprSchemable for Expr {
     fn data_type_and_nullable(
         &self,
         schema: &dyn ExprSchema,
-    ) -> Result<(DataType, bool)> {
+    ) -> Result<(LogicalType, bool)> {
         match self {
             Expr::Alias(Alias { expr, name, .. }) => match &**expr {
                 Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
@@ -411,7 +431,7 @@ impl ExprSchemable for Expr {
                 .map(|(d, n)| (d.clone(), n)),
             Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
             Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
-            Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
+            Expr::Literal(l) => Ok((l.data_type().into(), l.is_null())),
             Expr::IsNull(_)
             | Expr::IsNotNull(_)
             | Expr::IsTrue(_)
@@ -420,7 +440,7 @@ impl ExprSchemable for Expr {
             | Expr::IsNotTrue(_)
             | Expr::IsNotFalse(_)
             | Expr::IsNotUnknown(_)
-            | Expr::Exists { .. } => Ok((DataType::Boolean, false)),
+            | Expr::Exists { .. } => Ok((LogicalType::Boolean, false)),
             Expr::ScalarSubquery(subquery) => Ok((
                 subquery.subquery.schema().field(0).data_type().clone(),
                 subquery.subquery.schema().field(0).is_nullable(),
@@ -432,7 +452,8 @@ impl ExprSchemable for Expr {
             }) => {
                 let left = left.data_type_and_nullable(schema)?;
                 let right = right.data_type_and_nullable(schema)?;
-                Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1))
+                // TODO(@notfilippo): do not convert to physical type
+                Ok((get_result_type(&left.0.physical_type(), op, &right.0.physical_type())?.into(), left.1 || right.1))
             }
             _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
         }
@@ -445,13 +466,13 @@ impl ExprSchemable for Expr {
     fn to_field(
         &self,
         input_schema: &dyn ExprSchema,
-    ) -> Result<(Option<TableReference>, Arc<Field>)> {
+    ) -> Result<(Option<TableReference>, Arc<LogicalField>)> {
         match self {
             Expr::Column(c) => {
                 let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
                 Ok((
                     c.relation.clone(),
-                    Field::new(&c.name, data_type, nullable)
+                    LogicalField::new(&c.name, data_type, nullable)
                         .with_metadata(self.metadata(input_schema)?)
                         .into(),
                 ))
            }
@@ -460,7 +481,7 @@ impl ExprSchemable for Expr {
                 let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
                 Ok((
                     relation.clone(),
-                    Field::new(name, data_type, nullable)
+                    LogicalField::new(name, data_type, nullable)
                         .with_metadata(self.metadata(input_schema)?)
                         .into(),
                 ))
            }
@@ -469,7 +490,7 @@ impl ExprSchemable for Expr {
                 let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
                 Ok((
                     None,
-                    Field::new(self.display_name()?, data_type, nullable)
+                    LogicalField::new(self.display_name()?, data_type, nullable)
                         .with_metadata(self.metadata(input_schema)?)
                         .into(),
                 ))
            }
@@ -483,7 +504,7 @@ impl ExprSchemable for Expr {
     ///
     /// This function errors when it is impossible to cast the
     /// expression to the target [arrow::datatypes::DataType].
-    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
+    fn cast_to(self, cast_to_type: &LogicalType, schema: &dyn ExprSchema) -> Result<Expr> {
         let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             return Ok(self);
@@ -493,7 +514,8 @@ impl ExprSchemable for Expr {
         // like all of the binary expressions below. Perhaps Expr should track the
         // type of the expression?
 
-        if can_cast_types(&this_type, cast_to_type) {
+        // TODO(@notfilippo): The basis for whether cast can be executed should be the logical type
+        if can_cast_types(&this_type.physical_type(), &cast_to_type.physical_type()) {
             match self {
                 Expr::ScalarSubquery(subquery) => {
                     Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
@@ -507,7 +529,7 @@ impl ExprSchemable for Expr {
 }
 
 /// cast subquery in InSubquery/ScalarSubquery to a given type.
-pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
+pub fn cast_subquery(subquery: Subquery, cast_to_type: &LogicalType) -> Result<Subquery> {
     if subquery.subquery.schema().field(0).data_type() == cast_to_type {
         return Ok(subquery);
     }
@@ -574,7 +596,7 @@ mod tests {
     fn test_between_nullability() {
         let get_schema = |nullable| {
             MockExprSchema::new()
-                .with_data_type(DataType::Int32)
+                .with_data_type(LogicalType::Int32)
                 .with_nullable(nullable)
         };
 
@@ -598,7 +620,7 @@ mod tests {
     fn test_inlist_nullability() {
         let get_schema = |nullable| {
             MockExprSchema::new()
-                .with_data_type(DataType::Int32)
+                .with_data_type(LogicalType::Int32)
                 .with_nullable(nullable)
         };
 
@@ -623,7 +645,7 @@ mod tests {
     fn test_like_nullability() {
         let get_schema = |nullable| {
             MockExprSchema::new()
-                .with_data_type(DataType::Utf8)
+                .with_data_type(LogicalType::Utf8)
                 .with_nullable(nullable)
         };
 
@@ -639,8 +661,8 @@ mod tests {
     fn expr_schema_data_type() {
         let expr = col("foo");
         assert_eq!(
-            DataType::Utf8,
-            expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
+            LogicalType::Utf8,
+            expr.get_type(&MockExprSchema::new().with_data_type(LogicalType::Utf8))
                 .unwrap()
         );
     }
@@ -651,7 +673,7 @@ mod tests {
         meta.insert("bar".to_string(), "buzz".to_string());
         let expr = col("foo");
         let schema = MockExprSchema::new()
-            .with_data_type(DataType::Int32)
+            .with_data_type(LogicalType::Int32)
             .with_metadata(meta.clone());
 
         // col and alias should be metadata-preserving
@@ -662,14 +684,14 @@ mod tests {
         assert_eq!(
             HashMap::new(),
             expr.clone()
-                .cast_to(&DataType::Int64, &schema)
+                .cast_to(&LogicalType::Int64, &schema)
                 .unwrap()
                 .metadata(&schema)
                 .unwrap()
         );
 
         let schema = DFSchema::from_unqualifed_fields(
-            vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())]
+            vec![LogicalField::new("foo", LogicalType::Int32, true).with_metadata(meta.clone())]
                 .into(),
HashMap::new(), ) @@ -682,7 +704,7 @@ mod tests { #[derive(Debug)] struct MockExprSchema { nullable: bool, - data_type: DataType, + data_type: LogicalType, error_on_nullable: bool, metadata: HashMap<String, String>, } @@ -691,7 +713,7 @@ mod tests { fn new() -> Self { Self { nullable: false, - data_type: DataType::Null, + data_type: LogicalType::Null, error_on_nullable: false, metadata: HashMap::new(), } } @@ -702,7 +724,7 @@ self } - fn with_data_type(mut self, data_type: DataType) -> Self { + fn with_data_type(mut self, data_type: LogicalType) -> Self { self.data_type = data_type; self } @@ -727,7 +749,7 @@ mod tests { } } - fn data_type(&self, _col: &Column) -> Result<&DataType> { + fn data_type(&self, _col: &Column) -> Result<&LogicalType> { Ok(&self.data_type) } @@ -735,7 +757,7 @@ Ok(&self.metadata) } - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + fn data_type_and_nullable(&self, col: &Column) -> Result<(&LogicalType, bool)> { Ok((self.data_type(col)?, self.nullable(col)?)) } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index f87151efd88b..3b5be7a839d7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -48,14 +48,15 @@ use crate::{ WriteOp, }; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; -use datafusion_common::{ - get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, -}; +use datafusion_common::{get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, UnnestOptions, ToDFSchema}; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::fields::LogicalFields; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::LogicalSchema; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -182,26 +183,26 @@ impl LogicalPlanBuilder { } let empty_schema = DFSchema::empty(); - let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols); + let mut field_types: Vec<LogicalType> = Vec::with_capacity(n_cols); for j in 0..n_cols { - let mut common_type: Option<DataType> = None; + let mut common_type: Option<LogicalType> = None; for (i, row) in values.iter().enumerate() { let value = &row[j]; let data_type = value.get_type(&empty_schema)?; - if data_type == DataType::Null { + if data_type == LogicalType::Null { continue; } if let Some(prev_type) = common_type { // get common type of each column values. - let Some(new_type) = values_coercion(&data_type, &prev_type) else { + let Some(new_type) = values_coercion(&data_type.physical_type(), &prev_type.physical_type()) else { return plan_err!("Inconsistent data type across values list at row {i} column {j}.
Was {prev_type} but found {data_type}"); }; - common_type = Some(new_type); + common_type = Some(new_type.into()); } else { common_type = Some(data_type.clone()); } } - field_types.push(common_type.unwrap_or(DataType::Utf8)); + field_types.push(common_type.unwrap_or(LogicalType::Utf8)); } // wrap cast if data type is not same as common type. for row in &mut values { @@ -220,7 +221,7 @@ impl LogicalPlanBuilder { .map(|(j, data_type)| { // naming is following convention https://www.postgresql.org/docs/current/queries-values.html let name = &format!("column{}", j + 1); - Field::new(name, data_type.clone(), true) + LogicalField::new(name, data_type.clone(), true) }) .collect::<Vec<_>>(); let dfschema = DFSchema::from_unqualifed_fields(fields.into(), HashMap::new())?; @@ -289,7 +290,7 @@ impl LogicalPlanBuilder { pub fn insert_into( input: LogicalPlan, table_name: impl Into<TableReference>, - table_schema: &Schema, + table_schema: &LogicalSchema, overwrite: bool, ) -> Result<Self> { let table_schema = table_schema.clone().to_dfschema_ref()?; @@ -383,7 +384,7 @@ impl LogicalPlanBuilder { } /// Make a builder for a prepare logical plan from the builder's plan - pub fn prepare(self, name: String, data_types: Vec<DataType>) -> Result<Self> { + pub fn prepare(self, name: String, data_types: Vec<LogicalType>) -> Result<Self> { Ok(Self::from(LogicalPlan::Prepare(Prepare { name, data_types, @@ -1181,7 +1182,7 @@ impl From<Arc<LogicalPlan>> for LogicalPlanBuilder { } } -pub fn change_redundant_column(fields: &Fields) -> Vec<Field> { +pub fn change_redundant_column(fields: &LogicalFields) -> Vec<LogicalField> { let mut name_map = HashMap::new(); fields .into_iter() @@ -1190,7 +1191,7 @@ pub fn change_redundant_column(fields: &LogicalFields) -> Vec<LogicalField> { *counter += 1; if *counter > 1 { let new_name = format!("{}:{}", field.name(), *counter - 1); - Field::new(new_name, field.data_type().clone(), field.is_nullable()) + LogicalField::new(new_name, field.data_type().clone(), field.is_nullable()) } else { field.as_ref().clone() } @@ -1205,8 +1206,8 @@ pub fn build_join_schema( join_type: &JoinType, ) -> Result<DFSchema> { fn nullify_fields<'a>( - fields: impl Iterator<Item = (Option<&'a TableReference>, &'a Arc<Field>)>, - ) -> Vec<(Option<TableReference>, Arc<Field>)> { + fields: impl Iterator<Item = (Option<&'a TableReference>, &'a Arc<LogicalField>)>, + ) -> Vec<(Option<TableReference>, Arc<LogicalField>)> { fields .map(|(q, f)| { // TODO: find a good way to do that @@ -1219,7 +1220,7 @@ let right_fields = right.iter(); let left_fields = left.iter(); - let qualified_fields: Vec<(Option<TableReference>, Arc<Field>)> = match join_type { + let qualified_fields: Vec<(Option<TableReference>, Arc<LogicalField>)> = match join_type { JoinType::Inner => { // left then right let left_fields = left_fields @@ -1377,9 +1378,10 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result Result) -> Result { // - Struct(field1, field2) returns ["a.field1","a.field2"] pub fn get_unnested_columns( col_name: &String, - data_type: &DataType, -) -> Result<Vec<(Column, Arc<Field>)>> { + data_type: &LogicalType, +) -> Result<Vec<(Column, Arc<LogicalField>)>> { let mut qualified_columns = Vec::with_capacity(1); match data_type { - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => { - let new_field = Arc::new(Field::new( + LogicalType::List(field) + | LogicalType::FixedSizeList(field, _) + | LogicalType::LargeList(field) => { + let new_field = Arc::new(LogicalField::new( col_name.clone(), field.data_type().clone(), // Unnesting may produce NULLs even if the list is not null.
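The builder code above leans on the two conversions that recur throughout this PR: `ExtensionType::physical_type()` to drop a `LogicalType` down to an Arrow `DataType` for the existing coercion machinery, and `Into<LogicalType>` to lift the result back. A minimal sketch of that round-trip, with a hypothetical three-variant enum standing in for the real `LogicalType` (names and mapping are assumptions for illustration, not the PR's implementation):

// Sketch only: a logical type that erases physical encoding details,
// plus a lossy mapping back to a canonical Arrow DataType.
use arrow::datatypes::DataType;

#[derive(Clone, Debug, PartialEq)]
enum MyLogicalType { // hypothetical stand-in for LogicalType
    Boolean,
    Int64,
    Utf8,
}

impl MyLogicalType {
    // One logical type may have several physical encodings; pick a canonical
    // one, mirroring the physical_type() calls in the hunks above.
    fn physical_type(&self) -> DataType {
        match self {
            MyLogicalType::Boolean => DataType::Boolean,
            MyLogicalType::Int64 => DataType::Int64,
            MyLogicalType::Utf8 => DataType::Utf8,
        }
    }
}

impl From<DataType> for MyLogicalType {
    fn from(dt: DataType) -> Self {
        match dt {
            DataType::Boolean => MyLogicalType::Boolean,
            DataType::Int64 => MyLogicalType::Int64,
            // distinct physical encodings collapse to one logical type
            DataType::Utf8 | DataType::LargeUtf8 => MyLogicalType::Utf8,
            other => unimplemented!("sketch only: {other:?}"),
        }
    }
}

The round-trip is deliberately lossy (e.g. `Utf8`, `LargeUtf8`, and dictionary-encoded strings would all collapse to a single logical string type), which is exactly why the `TODO(@notfilippo)` comments flag each place that still has to convert back to a physical type.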
@@ -1621,7 +1623,7 @@ pub fn get_unnested_columns( // let column = Column::from((None, &new_field)); qualified_columns.push((column, new_field)); } - DataType::Struct(fields) => { + LogicalType::Struct(fields) => { qualified_columns.extend(fields.iter().map(|f| { let new_name = format!("{}.{}", col_name, f.name()); let column = Column::from_name(&new_name); @@ -1670,10 +1672,10 @@ pub fn unnest_with_options( original_field.data_type(), )?; match original_field.data_type() { - DataType::List(_) - | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) => list_columns.push(index), - DataType::Struct(_) => struct_columns.push(index), + LogicalType::List(_) + | LogicalType::FixedSizeList(_, _) + | LogicalType::LargeList(_) => list_columns.push(index), + LogicalType::Struct(_) => struct_columns.push(index), _ => { panic!( "not reachable, should be caught by get_unnested_columns" ) } } @@ -1685,7 +1687,7 @@ .extend(std::iter::repeat(index).take(flatten_columns.len())); Ok(flatten_columns .iter() - .map(|col: &(Column, Arc<Field>)| { + .map(|col: &(Column, Arc<LogicalField>)| { (col.0.relation.to_owned(), col.1.to_owned()) }) .collect()) @@ -1720,6 +1722,7 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { + use arrow::datatypes::{DataType, Field, Fields}; use super::*; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; @@ -1752,7 +1755,7 @@ mod tests { .unwrap(); let expected = DFSchema::try_from_qualified_schema( TableReference::bare("employee_csv"), - &schema, + &schema.clone().into(), ) .unwrap(); assert_eq!(&expected, plan.schema().as_ref()); @@ -2109,7 +2112,7 @@ mod tests { // Check unnested field is a scalar let field = plan.schema().field_with_name(None, "strings").unwrap(); - assert_eq!(&DataType::Utf8, field.data_type()); + assert_eq!(&LogicalType::Utf8, field.data_type()); // Unnesting the singular struct column result into 2 new columns for each subfield let plan = nested_table_scan("test_table")? @@ -2127,7 +2130,7 @@ .schema() .field_with_name(None, &format!("struct_singular.{}", field_name)) .unwrap(); - assert_eq!(&DataType::UInt32, field.data_type()); + assert_eq!(&LogicalType::UInt32, field.data_type()); } // Unnesting multiple fields in separate plans @@ -2146,7 +2149,7 @@ // Check unnested struct list field should be a struct.
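For reference, a sketch of what the unnest typing above produces, assuming (as the Arrow `DataType::List`/`FieldRef` pairing suggests) that `LogicalType::List` wraps an `Arc<LogicalField>`; the helper name and signature are taken from this file, the construction is illustrative only:

// Sketch: a list column flattens to its element type; a struct column
// fans out to one column per member (see the match above).
use std::sync::Arc;
use datafusion_common::logical_type::field::LogicalField;
use datafusion_common::logical_type::LogicalType;

fn unnest_typing_example() -> datafusion_common::Result<()> {
    let list_ty = LogicalType::List(Arc::new(LogicalField::new(
        "item",
        LogicalType::Int64,
        true,
    )));
    // One output column ("a", Int64); nullable, since unnesting a
    // non-null list may still produce NULLs.
    let cols = get_unnested_columns(&"a".to_string(), &list_ty)?;
    assert_eq!(cols.len(), 1);
    Ok(())
}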
let field = plan.schema().field_with_name(None, "structs").unwrap(); - assert!(matches!(field.data_type(), DataType::Struct(_))); + assert!(matches!(field.data_type(), LogicalType::Struct(_))); // Unnesting multiple fields at the same time let cols = vec!["strings", "structs", "struct_singular"] @@ -2220,23 +2223,23 @@ mod tests { #[test] fn test_change_redundant_column() -> Result<()> { - let t1_field_1 = Field::new("a", DataType::Int32, false); - let t2_field_1 = Field::new("a", DataType::Int32, false); - let t2_field_3 = Field::new("a", DataType::Int32, false); - let t1_field_2 = Field::new("b", DataType::Int32, false); - let t2_field_2 = Field::new("b", DataType::Int32, false); + let t1_field_1 = LogicalField::new("a", LogicalType::Int32, false); + let t2_field_1 = LogicalField::new("a", LogicalType::Int32, false); + let t2_field_3 = LogicalField::new("a", LogicalType::Int32, false); + let t1_field_2 = LogicalField::new("b", LogicalType::Int32, false); + let t2_field_2 = LogicalField::new("b", LogicalType::Int32, false); let field_vec = vec![t1_field_1, t2_field_1, t1_field_2, t2_field_2, t2_field_3]; - let remove_redundant = change_redundant_column(&Fields::from(field_vec)); + let remove_redundant = change_redundant_column(&LogicalFields::from(field_vec)); assert_eq!( remove_redundant, vec![ - Field::new("a", DataType::Int32, false), - Field::new("a:1", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("b:1", DataType::Int32, false), - Field::new("a:2", DataType::Int32, false), + LogicalField::new("a", LogicalType::Int32, false), + LogicalField::new("a:1", LogicalType::Int32, false), + LogicalField::new("b", LogicalType::Int32, false), + LogicalField::new("b:1", LogicalType::Int32, false), + LogicalField::new("a:2", LogicalType::Int32, false), ] ); Ok(()) diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 45ddbafecfd7..8b938a808b96 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -24,9 +24,9 @@ use std::{ use crate::{Expr, LogicalPlan, Volatility}; -use arrow::datatypes::DataType; use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; use sqlparser::ast::Ident; +use datafusion_common::logical_type::LogicalType; /// Various types of DDL (CREATE / DROP) catalog manipulation #[derive(Clone, PartialEq, Eq, Hash)] @@ -322,7 +322,7 @@ pub struct CreateFunction { pub temporary: bool, pub name: String, pub args: Option<Vec<OperateFunctionArg>>, - pub return_type: Option<DataType>, + pub return_type: Option<LogicalType>, pub params: CreateFunctionBody, /// Dummy schema pub schema: DFSchemaRef, } pub struct OperateFunctionArg { // TODO: figure out how to support mode // pub mode: Option<ArgMode>, pub name: Option<Ident>, - pub data_type: DataType, + pub data_type: LogicalType, pub default_expr: Option<Expr>, } #[derive(Clone, PartialEq, Eq, Hash, Debug)] diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index c9eef9bd34cc..04124984af3f 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -20,10 +20,11 @@ use std::fmt::{self, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; - +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::LogicalSchema;
use crate::LogicalPlan; /// Operator that copies the contents of a database to file(s) @@ -130,8 +131,6 @@ impl Display for WriteOp { fn make_count_schema() -> DFSchemaRef { Arc::new( - Schema::new(vec![Field::new("count", DataType::UInt64, false)]) - .try_into() - .unwrap(), + LogicalSchema::new(vec![LogicalField::new("count", LogicalType::UInt64, false)]).try_into().unwrap() ) } diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 2f581c1928f4..182018cbaf65 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -349,5 +349,5 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T { } fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> { - schema.fields().iter().map(|f| f.name().clone()).collect() + schema.fields().iter().map(|f| f.name().to_string()).collect() } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 31f830a6a13d..3285f8809d6d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -40,7 +40,7 @@ use crate::{ TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -55,6 +55,9 @@ use crate::display::PgJsonVisitor; use crate::logical_plan::tree_node::unwrap_arc; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::{LogicalSchema, LogicalSchemaRef}; /// A `LogicalPlan` is a node in a tree of relational operators (such as /// Projection or Filter). @@ -351,20 +354,20 @@ impl LogicalPlan { } /// Returns the (fixed) output schema for explain plans - pub fn explain_schema() -> SchemaRef { - SchemaRef::new(Schema::new(vec![ + pub fn explain_schema() -> LogicalSchemaRef { + LogicalSchemaRef::new(Schema::new(vec![ Field::new("plan_type", DataType::Utf8, false), Field::new("plan", DataType::Utf8, false), - ])) + ]).into()) } /// Returns the (fixed) output schema for `DESCRIBE` plans - pub fn describe_schema() -> Schema { + pub fn describe_schema() -> LogicalSchema { Schema::new(vec![ Field::new("column_name", DataType::Utf8, false), Field::new("data_type", DataType::Utf8, false), Field::new("is_nullable", DataType::Utf8, false), - ]) + ]).into() } /// Returns all expressions (non-recursively) evaluated by the current @@ -1388,8 +1391,8 @@ impl LogicalPlan { /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, - ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> { - let mut param_types: HashMap<String, Option<DataType>> = HashMap::new(); + ) -> Result<HashMap<String, Option<LogicalType>>, DataFusionError> { + let mut param_types: HashMap<String, Option<LogicalType>> = HashMap::new(); self.apply_with_subqueries(|plan| { plan.apply_expressions(|expr| { @@ -2085,7 +2088,7 @@ impl SubqueryAlias { // functional dependencies: let func_dependencies = plan.schema().functional_dependencies().clone(); let schema = DFSchemaRef::new( - DFSchema::try_from_qualified_schema(alias.clone(), &schema)? + DFSchema::try_from_qualified_schema(alias.clone(), &schema.into())?
.with_functional_dependencies(func_dependencies)?, ); Ok(SubqueryAlias { @@ -2124,7 +2127,7 @@ impl Filter { // construction (such as with correlated subqueries) so we make a best effort here and // ignore errors resolving the expression against the schema. if let Ok(predicate_type) = predicate.get_type(input.schema()) { - if predicate_type != DataType::Boolean { + if predicate_type != LogicalType::Boolean { return plan_err!( "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" ); @@ -2257,7 +2260,7 @@ pub struct Window { impl Window { /// Create a new window operator. pub fn try_new(window_expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<Self> { - let fields: Vec<(Option<TableReference>, Arc<Field>)> = input + let fields: Vec<(Option<TableReference>, Arc<LogicalField>)> = input .schema() .iter() .map(|(q, f)| (q.cloned(), f.clone())) @@ -2398,7 +2401,7 @@ impl TableScan { if table_name.table().is_empty() { return plan_err!("table_name cannot be empty"); } - let schema = table_source.schema(); + let schema: LogicalSchema = table_source.schema().as_ref().clone().into(); let func_dependencies = FunctionalDependencies::new_from_constraints( table_source.constraints(), schema.fields.len(), @@ -2412,7 +2415,7 @@ impl TableScan { let df_schema = DFSchema::new_with_metadata( p.iter() .map(|i| { - (Some(table_name.clone()), Arc::new(schema.field(*i).clone())) + (Some(table_name.clone()), schema.field(*i).clone()) }) .collect(), schema.metadata.clone(), @@ -2473,7 +2476,7 @@ pub struct Prepare { /// The name of the statement pub name: String, /// Data types of the parameters ([`Expr::Placeholder`]) - pub data_types: Vec<DataType>, + pub data_types: Vec<LogicalType>, /// The logical plan of the statements pub input: Arc<LogicalPlan>, } @@ -3494,7 +3497,7 @@ digraph { let schema = Arc::new( DFSchema::try_from_qualified_schema( TableReference::bare("tab"), - &source.schema(), + &source.schema().as_ref().clone().into(), ) .unwrap(), ); diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index ccf45ff0d048..8c647f16bd75 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -17,9 +17,8 @@ //! Structs and traits to provide the information needed for expression simplification.
-use arrow::datatypes::DataType; use datafusion_common::{DFSchemaRef, DataFusionError, Result}; - +use datafusion_common::logical_type::LogicalType; use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; /// Provides the information necessary to apply algebraic simplification to an @@ -39,7 +38,7 @@ pub trait SimplifyInfo { fn execution_props(&self) -> &ExecutionProps; /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result<DataType>; + fn get_data_type(&self, expr: &Expr) -> Result<LogicalType>; } /// Provides simplification information based on DFSchema and @@ -75,7 +74,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { /// returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result<bool> { for schema in &self.schema { - if let Ok(DataType::Boolean) = expr.get_type(schema) { + if let Ok(LogicalType::Boolean) = expr.get_type(schema) { return Ok(true); } } @@ -94,7 +93,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { } /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result<DataType> { + fn get_data_type(&self, expr: &Expr) -> Result<LogicalType> { let schema = self.schema.as_ref().ok_or_else(|| { DataFusionError::Internal( "attempt to get data type without schema".to_string(), diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 5645a2a4dede..0f144c6ca89b 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +// TODO(@notfilippo): make most of these accept LogicalType + //! Coercion rules for matching argument types for binary operators use std::collections::HashSet; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 5f060a4a4f16..e6f89efb9dfa 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -585,7 +585,7 @@ fn coerced_from<'a>( // List or LargeList with different dimensions should be handled in TypeSignature or other places before this (List(_) | LargeList(_), _) if datafusion_common::utils::base_type(type_from).eq(&Null) - || list_ndims(type_from) == list_ndims(type_into) => + || list_ndims(&type_from) == list_ndims(&type_into) => { Some(type_into.clone()) } diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 86005da3dafa..275700adc9a1 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -36,52 +36,53 @@ pub mod binary; pub mod functions; pub mod other; -use arrow::datatypes::DataType; +use datafusion_common::logical_type::LogicalType; + /// Determine whether the given data type `dt` represents signed numeric values.
-pub fn is_signed_numeric(dt: &DataType) -> bool { +pub fn is_signed_numeric(dt: &LogicalType) -> bool { matches!( dt, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _), + LogicalType::Int8 + | LogicalType::Int16 + | LogicalType::Int32 + | LogicalType::Int64 + | LogicalType::Float16 + | LogicalType::Float32 + | LogicalType::Float64 + | LogicalType::Decimal128(_, _) + | LogicalType::Decimal256(_, _), ) } /// Determine whether the given data type `dt` is `Null`. -pub fn is_null(dt: &DataType) -> bool { - *dt == DataType::Null +pub fn is_null(dt: &LogicalType) -> bool { + *dt == LogicalType::Null } /// Determine whether the given data type `dt` is a `Timestamp`. -pub fn is_timestamp(dt: &DataType) -> bool { - matches!(dt, DataType::Timestamp(_, _)) +pub fn is_timestamp(dt: &LogicalType) -> bool { + matches!(dt, LogicalType::Timestamp(_, _)) } /// Determine whether the given data type 'dt' is a `Interval`. -pub fn is_interval(dt: &DataType) -> bool { - matches!(dt, DataType::Interval(_)) +pub fn is_interval(dt: &LogicalType) -> bool { - matches!(dt, LogicalType::Interval(_)) } /// Determine whether the given data type `dt` is a `Date` or `Timestamp`. -pub fn is_datetime(dt: &DataType) -> bool { +pub fn is_datetime(dt: &LogicalType) -> bool { matches!( dt, - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) + LogicalType::Date32 | LogicalType::Date64 | LogicalType::Timestamp(_, _) ) } /// Determine whether the given data type `dt` is a `Utf8` or `LargeUtf8`. -pub fn is_utf8_or_large_utf8(dt: &DataType) -> bool { - matches!(dt, DataType::Utf8 | DataType::LargeUtf8) +pub fn is_utf8_or_large_utf8(dt: &LogicalType) -> bool { + matches!(dt, LogicalType::Utf8 | LogicalType::LargeUtf8) } /// Determine whether the given data type `dt` is a `Decimal`. -pub fn is_decimal(dt: &DataType) -> bool { - matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) +pub fn is_decimal(dt: &LogicalType) -> bool { + matches!(dt, LogicalType::Decimal128(_, _) | LogicalType::Decimal256(_, _)) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 286f05309ea7..3a55297d39ba 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -28,7 +28,7 @@ use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, }; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -39,6 +39,9 @@ use datafusion_common::{ }; use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::LogicalSchema; /// The value to which `COUNT(*)` is expanded to in /// `COUNT(<constant>)` expressions @@ -429,7 +432,7 @@ pub fn expand_qualified_wildcard( return plan_err!("Invalid qualifier {qualifier}"); } - let qualified_schema = Arc::new(Schema::new(fields_with_qualified)); + let qualified_schema = Arc::new(LogicalSchema::new(fields_with_qualified)); let qualified_dfschema = DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
.with_functional_dependencies(projected_func_dependencies)?; @@ -727,7 +730,7 @@ pub fn from_plan( pub fn exprlist_to_fields<'a>( exprs: impl IntoIterator<Item = &'a Expr>, plan: &LogicalPlan, -) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> { +) -> Result<Vec<(Option<TableReference>, Arc<LogicalField>)>> { // look for exact match in plan's output schema let input_schema = &plan.schema(); exprs @@ -830,40 +833,35 @@ pub(crate) fn find_column_indexes_referenced_by_expr( /// can this data type be used in hash join equal conditions?? /// data types here come from function 'equal_rows', if more data types are supported /// in equal_rows(hash join), add those data types here to generate join logical plan. -pub fn can_hash(data_type: &DataType) -> bool { +pub fn can_hash(data_type: &LogicalType) -> bool { match data_type { - DataType::Null => true, - DataType::Boolean => true, - DataType::Int8 => true, - DataType::Int16 => true, - DataType::Int32 => true, - DataType::Int64 => true, - DataType::UInt8 => true, - DataType::UInt16 => true, - DataType::UInt32 => true, - DataType::UInt64 => true, - DataType::Float32 => true, - DataType::Float64 => true, - DataType::Timestamp(time_unit, _) => match time_unit { + LogicalType::Null => true, + LogicalType::Boolean => true, + LogicalType::Int8 => true, + LogicalType::Int16 => true, + LogicalType::Int32 => true, + LogicalType::Int64 => true, + LogicalType::UInt8 => true, + LogicalType::UInt16 => true, + LogicalType::UInt32 => true, + LogicalType::UInt64 => true, + LogicalType::Float32 => true, + LogicalType::Float64 => true, + LogicalType::Timestamp(time_unit, _) => match time_unit { TimeUnit::Second => true, TimeUnit::Millisecond => true, TimeUnit::Microsecond => true, TimeUnit::Nanosecond => true, }, - DataType::Utf8 => true, - DataType::LargeUtf8 => true, - DataType::Decimal128(_, _) => true, - DataType::Date32 => true, - DataType::Date64 => true, - DataType::FixedSizeBinary(_) => true, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - DataType::is_dictionary_key_type(key_type) - } - DataType::List(_) => true, - DataType::LargeList(_) => true, - DataType::FixedSizeList(_, _) => true, + LogicalType::Utf8 => true, + LogicalType::LargeUtf8 => true, + LogicalType::Decimal128(_, _) => true, + LogicalType::Date32 => true, + LogicalType::Date64 => true, + LogicalType::FixedSizeBinary(_) => true, + LogicalType::List(_) => true, + LogicalType::LargeList(_) => true, + LogicalType::FixedSizeList(_, _) => true, _ => false, } } @@ -1249,6 +1247,7 @@ impl AggregateOrderSensitivity { #[cfg(test)] mod tests { + use datafusion_common::logical_type::LogicalType; use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, @@ -1703,11 +1702,11 @@ fn test_collect_expr() -> Result<()> { let mut accum: HashSet<Column> = HashSet::new(); expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::Cast(Cast::new(Box::new(col("a")), LogicalType::Float64)), &mut accum, )?; expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::Cast(Cast::new(Box::new(col("a")), LogicalType::Float64)), &mut accum, )?; assert_eq!(1, accum.len()); diff --git a/datafusion/expr/src/var_provider.rs b/datafusion/expr/src/var_provider.rs index e00cf7407237..b746955630b3 100644 --- a/datafusion/expr/src/var_provider.rs +++ b/datafusion/expr/src/var_provider.rs @@ -17,8 +17,8 @@ //! Variable provider -use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; +use datafusion_common::logical_type::LogicalType; /// Variable type, system/user defined #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,7 +35,7 @@ pub trait VarProvider: std::fmt::Debug { fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue>; /// Return the type of the given variable - fn get_type(&self, var_names: &[String]) -> Option<DataType>; + fn get_type(&self, var_names: &[String]) -> Option<LogicalType>; } pub fn is_system_variables(variable_names: &[String]) -> bool { diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 9c410d4e18e8..d3200a0a10d0 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -107,7 +107,7 @@ impl ScalarUDFImpl for ArrowCastFunc { info: &dyn SimplifyInfo, ) -> Result<ExprSimplifyResult> { // convert this into a real cast - let target_type = data_type_from_args(&args)?; + let target_type = data_type_from_args(&args)?.into(); // remove second (type) argument args.pop().unwrap(); let arg = args.pop().unwrap(); @@ -130,6 +130,8 @@ impl ScalarUDFImpl for ArrowCastFunc { /// Returns the requested type from the arguments fn data_type_from_args(args: &[Expr]) -> Result<DataType> { + // TODO(@notfilippo): maybe parse LogicalType? + if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b76da15c52ca..8aaae7b740d2 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -26,6 +26,8 @@ use datafusion_common::{ use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::LogicalType; #[derive(Debug)] pub struct GetFieldFunc { @@ -105,35 +107,36 @@ impl ScalarUDFImpl for GetFieldFunc { ); } }; + // TODO(@notfilippo): avoid converting to physical type let data_type = args[0].get_type(schema)?; match (data_type, name) { - (DataType::Map(fields, _), _) => { + (LogicalType::Map(fields, _), _) => { match fields.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { + LogicalType::Struct(fields) if fields.len() == 2 => { // Arrow's MapArray is essentially a ListArray of structs with two columns. They are // often named "key", and "value", but we don't require any specific naming here; // instead, we assume that the second columnis the "value" column both here and in // execution.
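The `VarProvider` hunk above is the simplest place to see the new surface area. A sketch of an implementation against the trait shape shown in that hunk (the import paths are assumptions, since `logical_type` is introduced by this PR):

// Sketch: a system-variable provider that reports a LogicalType
// instead of an Arrow DataType.
use datafusion_common::logical_type::LogicalType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::var_provider::VarProvider;

#[derive(Debug)]
struct VersionVar;

impl VarProvider for VersionVar {
    fn get_value(&self, _var_names: Vec<String>) -> Result<ScalarValue> {
        Ok(ScalarValue::from("1.0.0"))
    }

    fn get_type(&self, var_names: &[String]) -> Option<LogicalType> {
        // The planner only needs the logical type here; the physical Arrow
        // encoding is resolved later (e.g. via ExtensionType::physical_type()).
        (var_names.first().map(String::as_str) == Some("@@version"))
            .then(|| LogicalType::Utf8)
    }
}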
let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(value_field.data_type().clone()) + Ok(value_field.data_type().physical_type()) }, _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), } } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + (LogicalType::Struct(fields), ScalarValue::Utf8(Some(s))) => { if s.is_empty() { plan_err!( "Struct based indexed access requires a non empty string" ) } else { let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone().physical_type()) } } - (DataType::Struct(_), _) => plan_err!( + (LogicalType::Struct(_), _) => plan_err!( "Only UTF8 strings are valid as an indexed field in a struct" ), - (DataType::Null, _) => Ok(DataType::Null), + (LogicalType::Null, _) => Ok(DataType::Null), (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 8ccda977f3a4..eeadd18c6215 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -22,6 +22,7 @@ use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; +use datafusion_common::logical_type::extension::ExtensionType; /// put values in a struct array. fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> { @@ -139,7 +140,7 @@ impl ScalarUDFImpl for NamedStructFunc { let value = &chunk[1]; if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name { - Ok(Field::new(name, value.get_type(schema)?, true)) + Ok(Field::new(name, value.get_type(schema)?.physical_type(), true)) } else { exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 0791561539e1..0d4d66eb5e42 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -28,6 +28,7 @@ use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, }; +use datafusion_common::logical_type::extension::ExtensionType; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; @@ -158,6 +159,7 @@ impl ScalarUDFImpl for LogFunc { Ok(ColumnarValue::Array(arr)) } + // TODO(@notfilippo): avoid converting to physical type /// Simplify the `log` function by the relevant rules: /// 1. Log(a, 1) ===> 0 /// 2. Log(a, Power(a, b)) ===> b @@ -182,13 +184,13 @@ let base = if let Some(base) = args.pop() { base } else { - lit(ScalarValue::new_ten(&number_datatype)?) + lit(ScalarValue::new_ten(&number_datatype.physical_type())?) }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype.physical_type())?
=> { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( - &info.get_data_type(&base)?, + &info.get_data_type(&base)?.physical_type(), )?))) } Expr::ScalarFunction(ScalarFunction { func, mut args }) @@ -200,7 +202,7 @@ impl ScalarUDFImpl for LogFunc { number => { if number == base { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( - &number_datatype, + &number_datatype.physical_type(), )?))) } else { let args = match num_args { diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 5b790fb56ddf..77a131997c41 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -32,7 +32,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; - +use datafusion_common::logical_type::extension::ExtensionType; use super::log::LogFunc; #[derive(Debug)] @@ -127,6 +127,7 @@ impl ScalarUDFImpl for PowerFunc { Ok(ColumnarValue::Array(arr)) } + // TODO(@notfilippo): avoid converting to physical type /// Simplify the `power` function by the relevant rules: /// 1. Power(a, 0) ===> 0 /// 2. Power(a, 1) ===> a @@ -143,11 +144,11 @@ impl ScalarUDFImpl for PowerFunc { plan_datafusion_err!("Expected power to have 2 arguments, got 1") })?; - let exponent_type = info.get_data_type(&exponent)?; + let exponent_type = info.get_data_type(&exponent)?.physical_type(); match exponent { Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::new_one(&info.get_data_type(&base)?)?, + ScalarValue::new_one(&info.get_data_type(&base)?.physical_type())?, ))) } Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 34f9802b1fd9..1074c04aa395 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -98,7 +98,6 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { mod tests { use super::*; use crate::test::*; - use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; use datafusion_expr::{ @@ -108,7 +107,7 @@ mod tests { }; use datafusion_functions_aggregate::count::count_udaf; use std::sync::Arc; - + use datafusion_common::logical_type::LogicalType; use datafusion_functions_aggregate::expr_fn::{count, sum}; fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -192,7 +191,7 @@ mod tests { .filter( scalar_subquery(Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? + .filter(out_ref_col(LogicalType::UInt32, "t1.a").eq(col("t2.a")))? 
.aggregate( Vec::<Expr>::new(), vec![count(lit(COUNT_STAR_EXPANSION))], diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 098c934bf7e1..86d106bb8975 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -53,7 +53,7 @@ impl ApplyFunctionRewrites { if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), - &ts.source.schema(), + &ts.source.schema().as_ref().clone().into(), )?; schema.merge(&source_schema); } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 51ec8d8af1d3..9671f849ce2a 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -27,6 +27,8 @@ use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::LogicalType; use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, @@ -93,7 +95,7 @@ fn analyze_internal( if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), - &ts.source.schema(), + &ts.source.schema().as_ref().clone().into(), )?; schema.merge(&source_schema); } @@ -161,13 +163,13 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, ) -> Result<(Expr, Expr)> { let (left_type, right_type) = get_input_types( - &left.get_type(self.schema)?, + &left.get_type(self.schema)?.physical_type(), op, - &right.get_type(self.schema)?, + &right.get_type(self.schema)?.physical_type(), )?; Ok(( - left.cast_to(&left_type, self.schema)?, - right.cast_to(&right_type, self.schema)?, + left.cast_to(&left_type.into(), self.schema)?, + right.cast_to(&right_type.into(), self.schema)?, )) } } @@ -210,7 +212,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); - let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( + let common_type = comparison_coercion(&expr_type.physical_type(), &subquery_type.physical_type()).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" ), )?; let new_subquery = Subquery { subquery: new_plan, outer_ref_columns: subquery.outer_ref_columns, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, self.schema)?), - cast_subquery(new_subquery, &common_type)?, + Box::new(expr.cast_to(&common_type.clone().into(), self.schema)?), + cast_subquery(new_subquery, &common_type.into())?, negated, )))) } @@ -255,7 +257,7 @@ }) => { let left_type = expr.get_type(self.schema)?; let right_type = pattern.get_type(self.schema)?; - let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { + let coerced_type = like_coercion(&left_type.physical_type(), &right_type.physical_type()).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" } else { "LIKE" }; plan_datafusion_err!( "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; let expr = match left_type { - DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, - _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), + _ => Box::new(expr.cast_to(&coerced_type.clone().into(), self.schema)?), }; - let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type.into(), self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, pattern, escape_char, case_insensitive, )))) } @@ -294,14 +295,14 @@ }) => { let expr_type = expr.get_type(self.schema)?; let low_type = low.get_type(self.schema)?; - let low_coerced_type = comparison_coercion(&expr_type, &low_type) + let low_coerced_type = comparison_coercion(&expr_type.physical_type(), &low_type.physical_type()) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" )) })?; let high_type = high.get_type(self.schema)?; - let high_coerced_type = comparison_coercion(&expr_type, &low_type) + let high_coerced_type = comparison_coercion(&expr_type.physical_type(), &low_type.physical_type()) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; let coercion_type = comparison_coercion(&low_coerced_type, &high_coerced_type) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) - })?; + })?.into(); Ok(Transformed::yes(Expr::Between(Between::new( Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, Box::new(low.cast_to(&coercion_type, self.schema)?), Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { expr, list, negated, }) => { let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(self.schema)) + .map(|list_expr| list_expr.get_type(self.schema).map(|t| t.physical_type())) .collect::<Result<Vec<_>>>()?; let result_type = - get_coerce_type_for_list(&expr_data_type, &list_data_types); + get_coerce_type_for_list(&expr_data_type.physical_type(), &list_data_types); match result_type { None => plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, self.schema)?; + let logical_coerced_type = coerced_type.into(); + let cast_expr = expr.cast_to(&logical_coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, self.schema) + list_expr.cast_to(&logical_coerced_type, self.schema) }) .collect::<Result<Vec<_>>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -473,11 +476,11 @@ /// Casts the given `value` to `target_type`. Note that this function /// only considers `Null` or `Utf8` values.
-fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> { +fn coerce_scalar(target_type: &LogicalType, value: &ScalarValue) -> Result<ScalarValue> { match value { // Coerce Utf8 values: ScalarValue::Utf8(Some(val)) => { - ScalarValue::try_from_string(val.clone(), target_type) + ScalarValue::try_from_string(val.clone(), &target_type.physical_type()) } s => { if s.is_null() { @@ -500,7 +503,7 @@ fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result Result { coerce_scalar(target_type, &value).or_else(|err| { @@ -519,18 +522,18 @@ fn coerce_scalar_range_aware( /// This function returns the widest type in the family of `given_type`. /// If the given type is already the widest type, it returns `None`. /// For example, if `given_type` is `Int8`, it returns `Int64`. -fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> { +fn get_widest_type_in_family(given_type: &LogicalType) -> Option<&LogicalType> { match given_type { - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64), - DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64), - DataType::Float16 | DataType::Float32 => Some(&DataType::Float64), + LogicalType::UInt8 | LogicalType::UInt16 | LogicalType::UInt32 => Some(&LogicalType::UInt64), + LogicalType::Int8 | LogicalType::Int16 | LogicalType::Int32 => Some(&LogicalType::Int64), + LogicalType::Float16 | LogicalType::Float32 => Some(&LogicalType::Float64), _ => None, } } /// Coerces the given (window frame) `bound` to `target_type`. fn coerce_frame_bound( - target_type: &DataType, + target_type: &LogicalType, bound: WindowFrameBound, ) -> Result<WindowFrameBound> { match bound { @@ -561,11 +564,11 @@ fn coerce_window_frame( if let Some(col_type) = current_types.first() { if col_type.is_numeric() || is_utf8_or_large_utf8(col_type) - || matches!(col_type, DataType::Null) + || matches!(col_type, LogicalType::Null) { col_type } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) + &LogicalType::Interval(IntervalUnit::MonthDayNano) } else { return internal_err!( "Cannot run range queries on datatype: {col_type:?}" ) } } else { return internal_err!("ORDER BY column cannot be empty"); } } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + WindowFrameUnits::Rows | WindowFrameUnits::Groups => &LogicalType::UInt64, }; window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; @@ -586,8 +589,8 @@ // The above op will be rewrite to the binary op when creating the physical op.
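The helper below is one of many call sites in this file that follow the same recipe: coerce on physical types, cast on logical ones. The recipe in isolation (a sketch; `common_logical_type` is a hypothetical helper, `comparison_coercion` is the existing physical-type rule this PR keeps using):

// Sketch: coercion rules still operate on Arrow DataTypes, so drop down
// with physical_type() and lift the result back with Into<LogicalType>.
use datafusion_common::logical_type::extension::ExtensionType;
use datafusion_common::logical_type::LogicalType;
use datafusion_expr::type_coercion::binary::comparison_coercion;

fn common_logical_type(l: &LogicalType, r: &LogicalType) -> Option<LogicalType> {
    comparison_coercion(&l.physical_type(), &r.physical_type()).map(Into::into)
}

This round-trip is the interim state the `TODO(@notfilippo)` comments point at; once the coercion tables themselves accept `LogicalType`, the two conversions should disappear.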
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> { let left_type = expr.get_type(schema)?; - get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; - expr.cast_to(&DataType::Boolean, schema) + get_input_types(&left_type.physical_type(), &Operator::IsDistinctFrom, &DataType::Boolean)?; + expr.cast_to(&LogicalType::Boolean, schema) } /// Returns `expressions` coerced to types compatible with @@ -605,15 +608,15 @@ fn coerce_arguments_for_signature_with_scalar_udf( let current_types = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.get_type(schema).map(|t| t.physical_type())) .collect::<Result<Vec<_>>>()?; let new_types = data_types_with_scalar_udf(&current_types, func)?; expressions .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .zip(new_types) + .map(|(expr, t)| expr.cast_to(&t.into(), schema)) .collect() } @@ -632,15 +635,15 @@ fn coerce_arguments_for_signature_with_aggregate_udf( let current_types = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.get_type(schema).map(|t| t.physical_type())) .collect::<Result<Vec<_>>>()?; let new_types = data_types_with_aggregate_udf(&current_types, func)?; expressions .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .zip(new_types) + .map(|(expr, t)| expr.cast_to(&t.into(), schema)) .collect() } @@ -655,8 +658,8 @@ fn coerce_arguments_for_fun( .into_iter() .map(|expr| { let data_type = expr.get_type(schema).unwrap(); - if let DataType::FixedSizeList(field, _) = data_type { - let to_type = DataType::List(field.clone()); + if let LogicalType::FixedSizeList(field, _) = data_type { + let to_type = LogicalType::List(field.clone()); expr.cast_to(&to_type, schema) } else { Ok(expr) } @@ -682,7 +685,7 @@ fn coerce_agg_exprs_for_signature( } let current_types = input_exprs .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.get_type(schema).map(|t| t.physical_type())) .collect::<Result<Vec<_>>>()?; let coerced_types = @@ -690,8 +693,8 @@ input_exprs .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) + .zip(coerced_types) + .map(|(expr, t)| expr.cast_to(&t.into(), schema)) .collect() } @@ -735,12 +738,12 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> { let then_types = case .when_then_expr .iter() - .map(|(_when, then)| then.get_type(schema)) + .map(|(_when, then)| then.get_type(schema).map(|t| t.physical_type())) .collect::<Result<Vec<_>>>()?; let else_type = case .else_expr .as_ref() - .map(|expr| expr.get_type(schema)) + .map(|expr| expr.get_type(schema).map(|t| t.physical_type())) .transpose()?; // find common coercible types @@ -750,10 +753,10 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> { let when_types = case .when_then_expr .iter() - .map(|(when, _then)| when.get_type(schema)) + .map(|(when, _then)| when.get_type(schema).map(|t| t.physical_type())) .collect::<Result<Vec<_>>>()?; let coerced_type = - get_coerce_type_for_case_expression(&when_types, Some(case_type)); + get_coerce_type_for_case_expression(&when_types, Some(&case_type.physical_type())); coerced_type.ok_or_else(|| { plan_datafusion_err!( "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ @@ -776,7 +779,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> { let case_expr = case .expr .zip(case_when_coerce_type.as_ref()) - .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema)) + .map(|(case_expr, coercible_type)| case_expr.cast_to(&coercible_type.into(),
schema)) .transpose()? .map(Box::new); let when_then = case .when_then_expr .into_iter() .map(|(when, then)| { let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); - let when = when.cast_to(when_type, schema).map_err(|e| { + let when = when.cast_to(&when_type.into(), schema).map_err(|e| { DataFusionError::Context( format!( "WHEN expressions in CASE couldn't be \ @@ -784,7 +787,7 @@ Box::new(e), ) })?; - let then = then.cast_to(&then_else_coerce_type, schema)?; + let then = then.cast_to(&then_else_coerce_type.clone().into(), schema)?; Ok((Box::new(when), Box::new(then))) }) .collect::<Result<Vec<_>>>()?; let else_expr = case .else_expr - .map(|expr| expr.cast_to(&then_else_coerce_type, schema)) + .map(|expr| expr.cast_to(&then_else_coerce_type.into(), schema)) .transpose()? .map(Box::new); @@ -816,6 +819,8 @@ mod test { use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; + use datafusion_common::logical_type::field::LogicalField; + use datafusion_common::logical_type::LogicalType; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::test::function_stub::avg_udaf; @@ -839,12 +844,12 @@ mod test { })) } - fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> { + fn empty_with_type(data_type: LogicalType) -> Arc<LogicalPlan> { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new( DFSchema::from_unqualifed_fields( - vec![Field::new("a", data_type, true)].into(), + vec![LogicalField::new("a", data_type, true)].into(), std::collections::HashMap::new(), ) .unwrap(), @@ -855,7 +860,7 @@ #[test] fn simple_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); - let empty = empty_with_type(DataType::Float64); + let empty = empty_with_type(LogicalType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] fn nested_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); - let empty = empty_with_type(DataType::Float64); + let empty = empty_with_type(LogicalType::Float64); let plan = LogicalPlan::Projection(Projection::try_new( vec![expr.clone().or(expr)], @@ -894,7 +899,7 @@ } fn return_type(&self, _args: &[DataType]) -> Result<DataType> { - Ok(DataType::Utf8) + Ok(Utf8) } fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { @@ -1019,10 +1024,10 @@ let expected = "Projection: avg(Float64(12))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Int32); + let empty = empty_with_type(LogicalType::Int32); let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( avg_udaf(), - vec![cast(col("a"), DataType::Float64)], + vec![cast(col("a"), LogicalType::Float64)], false, None, None, @@ -1056,7 +1061,7 @@ #[test] fn binary_op_date32_op_interval() -> Result<()> { // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...") - let expr = cast(lit("1998-03-18"), DataType::Date32) + let expr = cast(lit("1998-03-18"), LogicalType::Date32) + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); let plan
= LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); @@ -1070,7 +1075,7 @@ mod test { fn inlist_case() -> Result<()> { // a in (1,4,8), a is int64 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); - let empty = empty_with_type(DataType::Int64); + let empty = empty_with_type(LogicalType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Literal(Int32(1)), Literal(Int8(4)), Literal(Int64(8))]) })\ @@ -1082,7 +1087,7 @@ mod test { let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::from_unqualifed_fields( - vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(), + vec![LogicalField::new("a", LogicalType::Decimal128(12, 4), true)].into(), std::collections::HashMap::new(), )?), })); @@ -1098,10 +1103,10 @@ mod test { let expr = col("a").between( lit("2002-05-08"), // (cast('2002-05-08' as date) + interval '1 months') - cast(lit("2002-05-08"), DataType::Date32) + cast(lit("2002-05-08"), LogicalType::Date32) + lit(ScalarValue::new_interval_ym(0, 1)), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); let expected = "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ @@ -1113,11 +1118,11 @@ mod test { fn between_infer_cheap_type() -> Result<()> { let expr = col("a").between( // (cast('2002-05-08' as date) + interval '1 months') - cast(lit("2002-05-08"), DataType::Date32) + cast(lit("2002-05-08"), LogicalType::Date32) + lit(ScalarValue::new_interval_ym(0, 1)), lit("2002-12-08"), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). 
let expected = @@ -1130,13 +1135,13 @@ mod test { fn is_bool_for_type_coercion() -> Result<()> { // is true let expr = col("a").is_true(); - let empty = empty_with_type(DataType::Boolean); + let empty = empty_with_type(LogicalType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Int64); + let empty = empty_with_type(LogicalType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); let err = ret.unwrap_err().to_string(); @@ -1144,21 +1149,21 @@ mod test { // is not true let expr = col("a").is_not_true(); - let empty = empty_with_type(DataType::Boolean); + let empty = empty_with_type(LogicalType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is false let expr = col("a").is_false(); - let empty = empty_with_type(DataType::Boolean); + let empty = empty_with_type(LogicalType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS FALSE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is not false let expr = col("a").is_not_false(); - let empty = empty_with_type(DataType::Boolean); + let empty = empty_with_type(LogicalType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1172,7 +1177,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1180,7 +1185,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL\ \n EmptyRelation"; @@ -1189,10 +1194,11 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Int64); + let empty = empty_with_type(LogicalType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); + println!("{:?}", err); assert!(err.unwrap_err().to_string().contains( "There isn't a common type to coerce Int64 and Utf8 in LIKE expression" )); @@ -1201,7 +1207,7 @@ mod test { let expr = Box::new(col("a")); let pattern = 
Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1209,7 +1215,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8) AS a ILIKE NULL\ \n EmptyRelation"; @@ -1218,7 +1224,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Int64); + let empty = empty_with_type(LogicalType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); @@ -1232,13 +1238,13 @@ mod test { fn unknown_for_type_coercion() -> Result<()> { // unknown let expr = col("a").is_unknown(); - let empty = empty_with_type(DataType::Boolean); + let empty = empty_with_type(LogicalType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); @@ -1246,7 +1252,7 @@ mod test { // is not unknown let expr = col("a").is_not_unknown(); - let empty = empty_with_type(DataType::Boolean); + let empty = empty_with_type(LogicalType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1256,7 +1262,7 @@ mod test { #[test] fn concat_for_type_coercion() -> Result<()> { - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(LogicalType::Utf8); let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature @@ -1279,34 +1285,34 @@ mod test { fn test_type_coercion_rewrite() -> Result<()> { // gt let schema = Arc::new(DFSchema::from_unqualifed_fields( - vec![Field::new("a", DataType::Int64, true)].into(), + vec![LogicalField::new("a", LogicalType::Int64, true)].into(), std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); - let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); + let expected = is_true(cast(lit(12i32), LogicalType::Int64).gt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq let schema = Arc::new(DFSchema::from_unqualifed_fields( - vec![Field::new("a", DataType::Int64, 
true)].into(), + vec![LogicalField::new("a", LogicalType::Int64, true)].into(), std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); - let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); + let expected = is_true(cast(lit(12i32), LogicalType::Int64).eq(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt let schema = Arc::new(DFSchema::from_unqualifed_fields( - vec![Field::new("a", DataType::Int64, true)].into(), + vec![LogicalField::new("a", LogicalType::Int64, true)].into(), std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); - let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); + let expected = is_true(cast(lit(12i32), LogicalType::Int64).lt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); @@ -1317,9 +1323,9 @@ mod test { fn binary_op_date32_eq_ts() -> Result<()> { let expr = cast( lit("1998-03-18"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + LogicalType::Timestamp(TimeUnit::Nanosecond, None), ) - .eq(cast(lit("1998-03-18"), DataType::Date32)); + .eq(cast(lit("1998-03-18"), LogicalType::Date32)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); dbg!(&plan); @@ -1331,7 +1337,7 @@ mod test { fn cast_if_not_same_type( expr: Box, - data_type: &DataType, + data_type: &LogicalType, schema: &DFSchemaRef, ) -> Box { if &expr.get_type(schema).unwrap() != data_type { @@ -1343,8 +1349,8 @@ mod test { fn cast_helper( case: Case, - case_when_type: DataType, - then_else_type: DataType, + case_when_type: LogicalType, + then_else_type: LogicalType, schema: &DFSchemaRef, ) -> Case { let expr = case @@ -1375,23 +1381,23 @@ mod test { fn test_case_expression_coercion() -> Result<()> { let schema = Arc::new(DFSchema::from_unqualifed_fields( vec![ - Field::new("boolean", DataType::Boolean, true), - Field::new("integer", DataType::Int32, true), - Field::new("float", DataType::Float32, true), - Field::new( + LogicalField::new("boolean", LogicalType::Boolean, true), + LogicalField::new("integer", LogicalType::Int32, true), + LogicalField::new("float", LogicalType::Float32, true), + LogicalField::new( "timestamp", - DataType::Timestamp(TimeUnit::Nanosecond, None), + LogicalType::Timestamp(TimeUnit::Nanosecond, None), true, ), - Field::new("date", DataType::Date32, true), - Field::new( + LogicalField::new("date", LogicalType::Date32, true), + LogicalField::new( "interval", - DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano), + LogicalType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano), true, ), - Field::new("binary", DataType::Binary, true), - Field::new("string", DataType::Utf8, true), - Field::new("decimal", DataType::Decimal128(10, 10), true), + LogicalField::new("binary", LogicalType::Binary, true), + LogicalField::new("string", LogicalType::Utf8, true), + LogicalField::new("decimal", LogicalType::Decimal128(10, 10), true), ] .into(), std::collections::HashMap::new(), @@ -1406,8 +1412,8 @@ mod test { ], else_expr: None, }; - let case_when_common_type = DataType::Boolean; - let then_else_common_type = DataType::Utf8; + let case_when_common_type = LogicalType::Boolean; + let then_else_common_type = LogicalType::Utf8; let expected = cast_helper( case.clone(), case_when_common_type, @@ -1426,8 +1432,8 @@ mod test 
{ ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = DataType::Utf8; - let then_else_common_type = DataType::Utf8; + let case_when_common_type = LogicalType::Utf8; + let then_else_common_type = LogicalType::Utf8; let expected = cast_helper( case.clone(), case_when_common_type, @@ -1484,7 +1490,7 @@ mod test { Operator::Plus, Box::new(cast( lit("2000-01-01T00:00:00"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + LogicalType::Timestamp(TimeUnit::Nanosecond, None), )), )); let empty = empty(); @@ -1499,12 +1505,12 @@ mod test { let expr = Expr::BinaryExpr(BinaryExpr::new( Box::new(cast( lit("1998-03-18"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + LogicalType::Timestamp(TimeUnit::Nanosecond, None), )), Operator::Minus, Box::new(cast( lit("1998-03-18"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + LogicalType::Timestamp(TimeUnit::Nanosecond, None), )), )); let empty = empty(); @@ -1518,8 +1524,8 @@ mod test { #[test] fn in_subquery_cast_subquery() -> Result<()> { - let empty_int32 = empty_with_type(DataType::Int32); - let empty_int64 = empty_with_type(DataType::Int64); + let empty_int32 = empty_with_type(LogicalType::Int32); + let empty_int64 = empty_with_type(LogicalType::Int64); let in_subquery_expr = Expr::InSubquery(InSubquery::new( Box::new(col("a")), @@ -1543,8 +1549,8 @@ mod test { #[test] fn in_subquery_cast_expr() -> Result<()> { - let empty_int32 = empty_with_type(DataType::Int32); - let empty_int64 = empty_with_type(DataType::Int64); + let empty_int32 = empty_with_type(LogicalType::Int32); + let empty_int64 = empty_with_type(LogicalType::Int64); let in_subquery_expr = Expr::InSubquery(InSubquery::new( Box::new(col("a")), @@ -1567,8 +1573,8 @@ mod test { #[test] fn in_subquery_cast_all() -> Result<()> { - let empty_inside = empty_with_type(DataType::Decimal128(10, 5)); - let empty_outside = empty_with_type(DataType::Decimal128(8, 8)); + let empty_inside = empty_with_type(LogicalType::Decimal128(10, 5)); + let empty_outside = empty_with_type(LogicalType::Decimal128(8, 8)); let in_subquery_expr = Expr::InSubquery(InSubquery::new( Box::new(col("a")), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e760845e043a..2958082108a2 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1094,6 +1094,8 @@ mod test { use std::iter; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::logical_type::field::LogicalField; + use datafusion_common::logical_type::LogicalType; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ @@ -1644,9 +1646,9 @@ mod test { let plan = table_scan(Some("table"), &schema, None) .unwrap() .filter( - cast(col("a"), DataType::Int64) + cast(col("a"), LogicalType::Int64) .lt(lit(1_i64)) - .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))), + .and(cast(col("a"), LogicalType::Int64).not_eq(lit(1_i64))), ) .unwrap() .build() @@ -1704,9 +1706,9 @@ mod test { let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]); let schema = DFSchema::from_unqualifed_fields( vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), + LogicalField::new("a", LogicalType::Int32, false), + LogicalField::new("b", LogicalType::Int32, false), + LogicalField::new("c", LogicalType::Int32, false), ] 
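
The `.into()` calls in the `coerce_case_expression` hunks above lean on a conversion from arrow's physical `DataType` into the new `LogicalType`. A minimal sketch of that bridge, assuming the `From<DataType> for LogicalType` impl this patch introduces (the function name is illustrative):

    use arrow::datatypes::DataType;
    use datafusion_common::logical_type::LogicalType;

    fn lift(physical: DataType) -> LogicalType {
        // Scalar arrow types are assumed to map onto the logical variant
        // of the same name, e.g. DataType::Boolean -> LogicalType::Boolean.
        physical.into()
    }
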
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index e760845e043a..2958082108a2 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -1094,6 +1094,8 @@ mod test {
     use std::iter;

     use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::logical_type::field::LogicalField;
+    use datafusion_common::logical_type::LogicalType;
     use datafusion_expr::expr::AggregateFunction;
     use datafusion_expr::logical_plan::{table_scan, JoinType};
     use datafusion_expr::{
@@ -1644,9 +1646,9 @@ mod test {
         let plan = table_scan(Some("table"), &schema, None)
             .unwrap()
             .filter(
-                cast(col("a"), DataType::Int64)
+                cast(col("a"), LogicalType::Int64)
                     .lt(lit(1_i64))
-                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
+                    .and(cast(col("a"), LogicalType::Int64).not_eq(lit(1_i64))),
             )
             .unwrap()
             .build()
@@ -1704,9 +1706,9 @@ mod test {
         let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
         let schema = DFSchema::from_unqualifed_fields(
             vec![
-                Field::new("a", DataType::Int32, false),
-                Field::new("b", DataType::Int32, false),
-                Field::new("c", DataType::Int32, false),
+                LogicalField::new("a", LogicalType::Int32, false),
+                LogicalField::new("b", LogicalType::Int32, false),
+                LogicalField::new("c", LogicalType::Int32, false),
             ]
             .into(),
             HashMap::default(),
@@ -1723,8 +1725,8 @@ mod test {
         let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
         let schema = DFSchema::from_unqualifed_fields(
             vec![
-                Field::new("a", DataType::Int32, false),
-                Field::new("b", DataType::Int32, false),
+                LogicalField::new("a", LogicalType::Int32, false),
+                LogicalField::new("b", LogicalType::Int32, false),
             ]
             .into(),
             HashMap::default(),
@@ -1791,7 +1793,7 @@ mod test {
     fn test_extract_expressions_from_col() -> Result<()> {
         let mut result = Vec::with_capacity(1);
         let schema = DFSchema::from_unqualifed_fields(
-            vec![Field::new("a", DataType::Int32, false)].into(),
+            vec![LogicalField::new("a", LogicalType::Int32, false)].into(),
             HashMap::default(),
         )?;
         extract_expressions(&col("a"), &schema, &mut result)?;
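
Across these test modules the schema-construction pattern changes in one way only: `Field`/`DataType` become `LogicalField`/`LogicalType`. A sketch of the rewritten pattern (field names are illustrative; `from_unqualifed_fields` is the constructor spelling the codebase already uses above):

    use std::collections::HashMap;
    use datafusion_common::logical_type::field::LogicalField;
    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::{DFSchema, Result};

    fn demo_schema() -> Result<DFSchema> {
        DFSchema::from_unqualifed_fields(
            // Two unqualified columns carrying logical, not physical, types.
            vec![
                LogicalField::new("a", LogicalType::Int32, false),
                LogicalField::new("b", LogicalType::Utf8, true),
            ]
            .into(),
            HashMap::new(),
        )
    }
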
diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
index 81d6dc863af6..7f009cbd1164 100644
--- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
+++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -369,7 +369,7 @@ mod tests {
     use super::*;
     use crate::test::*;

-    use arrow::datatypes::DataType;
+    use datafusion_common::logical_type::LogicalType;
     use datafusion_expr::{and, binary_expr, col, lit, not, or, out_ref_col};

     fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -564,7 +564,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -605,7 +605,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
                 .filter(
                     col("lineitem.l_orderkey")
-                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "orders.o_orderkey")),
                 )?
                 .project(vec![col("lineitem.l_orderkey")])?
                 .build()?,
@@ -616,7 +616,7 @@ mod tests {
             .filter(
                 in_subquery(col("orders.o_orderkey"), lineitem).and(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 ),
             )?
             .project(vec![col("orders.o_custkey")])?
@@ -653,7 +653,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .and(col("o_orderkey").eq(lit(1))),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -688,8 +688,8 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -752,7 +752,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .not_eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -785,7 +785,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .lt(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -818,7 +818,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .or(col("o_orderkey").eq(lit(1))),
                 )?
@@ -876,7 +876,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -909,7 +909,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey").add(lit(1))])?
@@ -942,7 +942,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
@@ -971,7 +971,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1008,7 +1008,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1045,7 +1045,7 @@ mod tests {
     fn in_subquery_correlated() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
-                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
+                .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq.a")))?
                 .project(vec![col("c")])?
                 .build()?,
         );
@@ -1203,7 +1203,7 @@ mod tests {
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
-                out_ref_col(DataType::UInt32, "test.a")
+                out_ref_col(LogicalType::UInt32, "test.a")
                     .eq(col("sq.a"))
                     .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
             )?
@@ -1238,8 +1238,8 @@ mod tests {
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
-                out_ref_col(DataType::UInt32, "test.a")
-                    .add(out_ref_col(DataType::UInt32, "test.b"))
+                out_ref_col(LogicalType::UInt32, "test.a")
+                    .add(out_ref_col(LogicalType::UInt32, "test.b"))
                     .eq(col("sq.a").add(col("sq.b")))
                     .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
             )?
@@ -1274,12 +1274,12 @@ mod tests {
         let subquery_scan2 = test_table_scan_with_name("sq2")?;

         let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
-            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq1.a")))?
+            .filter(out_ref_col(LogicalType::UInt32, "test.a").gt(col("sq1.a")))?
             .project(vec![col("c") * lit(2u32)])?
             .build()?;

         let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
-            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq2.a")))?
+            .filter(out_ref_col(LogicalType::UInt32, "test.a").gt(col("sq2.a")))?
             .project(vec![col("c") * lit(2u32)])?
             .build()?;
@@ -1351,7 +1351,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -1382,7 +1382,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
                 .filter(
                     col("lineitem.l_orderkey")
-                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "orders.o_orderkey")),
                 )?
                 .project(vec![col("lineitem.l_orderkey")])?
                 .build()?,
@@ -1393,7 +1393,7 @@ mod tests {
             .filter(
                 exists(lineitem).and(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 ),
             )?
             .project(vec![col("orders.o_custkey")])?
@@ -1424,7 +1424,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .and(col("o_orderkey").eq(lit(1))),
                 )?
@@ -1452,7 +1452,7 @@ mod tests {
     fn exists_subquery_no_cols() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
-                .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
+                .filter(out_ref_col(LogicalType::Int64, "customer.c_custkey").eq(lit(1u32)))?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
         );
@@ -1497,7 +1497,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .not_eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1525,7 +1525,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .lt(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1553,7 +1553,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .or(col("o_orderkey").eq(lit(1))),
                 )?
@@ -1582,7 +1582,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .build()?,
@@ -1608,7 +1608,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey").add(lit(1))])?
@@ -1636,7 +1636,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1690,7 +1690,7 @@ mod tests {
     fn exists_subquery_correlated() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
-                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
+                .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq.a")))?
                 .project(vec![col("c")])?
                 .build()?,
         );
@@ -1741,12 +1741,12 @@ mod tests {
         let subquery_scan2 = test_table_scan_with_name("sq2")?;

         let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
-            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))?
+            .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq1.a")))?
             .project(vec![col("c")])?
             .build()?;

         let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
-            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))?
+            .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq2.a")))?
             .project(vec![col("c")])?
             .build()?;
@@ -1780,7 +1780,7 @@ mod tests {
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
                 (lit(1u32) + col("sq.a"))
-                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+                    .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)),
             )?
             .project(vec![lit(1u32)])?
             .build()?;
@@ -1832,7 +1832,7 @@ mod tests {
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
                 (lit(1u32) + col("sq.a"))
-                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+                    .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)),
             )?
             .project(vec![col("sq.c")])?
             .distinct()?
@@ -1860,7 +1860,7 @@ mod tests {
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
                 (lit(1u32) + col("sq.a"))
-                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+                    .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)),
             )?
             .project(vec![col("sq.b") + col("sq.c")])?
             .distinct()?
@@ -1888,7 +1888,7 @@ mod tests {
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
                 (lit(1u32) + col("sq.a"))
-                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+                    .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)),
             )?
             .project(vec![lit(1u32), col("sq.c")])?
             .distinct()?
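
The bulk of the churn in this file is the new `out_ref_col` signature, which takes the outer column's `LogicalType`. A sketch of a correlated EXISTS filter in that style; `outer_scan` and `inner_scan` are hypothetical inputs standing in for the table scans the tests build:

    use std::sync::Arc;
    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::Result;
    use datafusion_expr::{col, exists, out_ref_col, LogicalPlan, LogicalPlanBuilder};

    fn demo(outer_scan: LogicalPlan, inner_scan: LogicalPlan) -> Result<LogicalPlan> {
        let subquery = Arc::new(
            LogicalPlanBuilder::from(inner_scan)
                // The outer reference now carries a logical type.
                .filter(out_ref_col(LogicalType::Int64, "outer.a").eq(col("inner.a")))?
                .project(vec![col("inner.a")])?
                .build()?,
        );
        LogicalPlanBuilder::from(outer_scan)
            .filter(exists(subquery))?
            .build()
    }
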
diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs
index edf6b72d7e17..d88491be1979 100644
--- a/datafusion/optimizer/src/eliminate_one_union.rs
+++ b/datafusion/optimizer/src/eliminate_one_union.rs
@@ -65,11 +65,12 @@ mod tests {
     use super::*;
     use crate::test::*;
     use arrow::datatypes::{DataType, Field, Schema};
-    use datafusion_common::ToDFSchema;
     use datafusion_expr::{
         expr_rewriter::coerce_plan_expr_for_schema, logical_plan::table_scan,
     };
     use std::sync::Arc;
+    use datafusion_common::logical_type::schema::LogicalSchema;
+    use datafusion_common::ToDFSchema;

     fn schema() -> Schema {
         Schema::new(vec![
@@ -108,7 +109,7 @@ mod tests {
     fn eliminate_one_union() -> Result<()> {
         let table_plan = coerce_plan_expr_for_schema(
             &table_scan(Some("table"), &schema(), None)?.build()?,
-            &schema().to_dfschema()?,
+            &LogicalSchema::from(schema()).to_dfschema()?,
         )?;
         let schema = table_plan.schema().clone();
         let single_union_plan = LogicalPlan::Union(Union {
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index ccc637a0eb01..882499cfa68b 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -301,7 +301,7 @@ fn extract_non_nullable_columns(
 mod tests {
     use super::*;
     use crate::test::*;
-    use arrow::datatypes::DataType;
+    use datafusion_common::logical_type::LogicalType;
     use datafusion_expr::{
         binary_expr, cast, col, lit,
         logical_plan::builder::LogicalPlanBuilder,
@@ -427,9 +427,9 @@ mod tests {
             None,
         )?
         .filter(binary_expr(
-            cast(col("t1.b"), DataType::Int64).gt(lit(10u32)),
+            cast(col("t1.b"), LogicalType::Int64).gt(lit(10u32)),
             And,
-            try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
+            try_cast(col("t2.c"), LogicalType::Int64).lt(lit(20u32)),
         ))?
         .build()?;
         let expected = "\
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 87d205139e8e..820ef149c347 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -156,11 +156,11 @@ fn split_eq_and_noneq_join_predicate(
 mod tests {
     use super::*;
     use crate::test::*;
-    use arrow::datatypes::DataType;
     use datafusion_expr::{
         col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
     };
     use std::sync::Arc;
+    use datafusion_common::logical_type::LogicalType;

     fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
         assert_optimized_plan_eq_display_indent(
@@ -362,8 +362,8 @@ mod tests {
         // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2
         let filter = Expr::eq(
-            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
-            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
+            col("t1.a") + lit(1i64).cast_to(&LogicalType::UInt32, &t1_schema)?,
+            col("t2.a") + lit(2i32).cast_to(&LogicalType::UInt32, &t2_schema)?,
         )
         .alias("t1.a + 1 = t2.a + 2");
         let plan = LogicalPlanBuilder::from(t1)
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index 2fbf77523bd1..e3d0bbf243c2 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -810,6 +810,8 @@ mod tests {
     use datafusion_common::{
         Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
     };
+    use datafusion_common::logical_type::field::LogicalField;
+    use datafusion_common::logical_type::LogicalType;
     use datafusion_expr::AggregateExt;
     use datafusion_expr::{
         binary_expr, build_join_schema,
@@ -1172,7 +1174,7 @@ mod tests {
     fn test_try_cast() -> Result<()> {
         let table_scan = test_table_scan()?;
         let plan = LogicalPlanBuilder::from(table_scan)
-            .project(vec![try_cast(col("a"), DataType::Float64)])?
+            .project(vec![try_cast(col("a"), LogicalType::Float64)])?
             .build()?;

         let expected = "Projection: TRY_CAST(test.a AS Float64)\
@@ -1544,15 +1546,15 @@ mod tests {
                 vec![
                     (
                         Some("test".into()),
-                        Arc::new(Field::new("a", DataType::UInt32, false))
+                        Arc::new(LogicalField::new("a", LogicalType::UInt32, false))
                     ),
                     (
                         Some("test".into()),
-                        Arc::new(Field::new("b", DataType::UInt32, false))
+                        Arc::new(LogicalField::new("b", LogicalType::UInt32, false))
                     ),
                     (
                         Some("test2".into()),
-                        Arc::new(Field::new("c1", DataType::UInt32, true))
+                        Arc::new(LogicalField::new("c1", LogicalType::UInt32, true))
                     ),
                 ],
                 HashMap::new()
@@ -1596,15 +1598,15 @@ mod tests {
                 vec![
                     (
                         Some("test".into()),
-                        Arc::new(Field::new("a", DataType::UInt32, false))
+                        Arc::new(LogicalField::new("a", LogicalType::UInt32, false))
                    ),
                     (
                         Some("test".into()),
-                        Arc::new(Field::new("b", DataType::UInt32, false))
+                        Arc::new(LogicalField::new("b", LogicalType::UInt32, false))
                     ),
                     (
                         Some("test2".into()),
-                        Arc::new(Field::new("c1", DataType::UInt32, true))
+                        Arc::new(LogicalField::new("c1", LogicalType::UInt32, true))
                     ),
                 ],
                 HashMap::new()
@@ -1646,15 +1648,15 @@ mod tests {
                 vec![
                     (
                         Some("test".into()),
-                        Arc::new(Field::new("a", DataType::UInt32, false))
+                        Arc::new(LogicalField::new("a", LogicalType::UInt32, false))
                     ),
                     (
                         Some("test".into()),
-                        Arc::new(Field::new("b", DataType::UInt32, false))
+                        Arc::new(LogicalField::new("b", LogicalType::UInt32, false))
                     ),
                     (
                         Some("test2".into()),
-                        Arc::new(Field::new("a", DataType::UInt32, true))
+                        Arc::new(LogicalField::new("a", LogicalType::UInt32, true))
                     ),
                 ],
                 HashMap::new()
@@ -1671,7 +1673,7 @@ mod tests {
         let projection = LogicalPlanBuilder::from(table_scan)
             .project(vec![Expr::Cast(Cast::new(
                 Box::new(col("c")),
-                DataType::Float64,
+                LogicalType::Float64,
             ))])?
             .build()?;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 14e5ac141eeb..37fd448eff43 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -540,17 +540,17 @@ mod tests {
             "Optimizer rule 'get table_scan rule' failed\n\
             caused by\nget table_scan rule\ncaused by\n\
             Internal error: Failed due to a difference in schemas, \
-            original schema: DFSchema { inner: Schema { \
+            original schema: DFSchema { inner: LogicalSchema { \
             fields: [], \
             metadata: {} }, \
             field_qualifiers: [], \
             functional_dependencies: FunctionalDependencies { deps: [] } \
             }, \
-            new schema: DFSchema { inner: Schema { \
+            new schema: DFSchema { inner: LogicalSchema { \
             fields: [\
-            Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \
-            Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \
-            Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }\
+            LogicalField { name: \"a\", data_type: UInt32, nullable: false, metadata: {} }, \
+            LogicalField { name: \"b\", data_type: UInt32, nullable: false, metadata: {} }, \
+            LogicalField { name: \"c\", data_type: UInt32, nullable: false, metadata: {} }\
             ], \
             metadata: {} }, \
             field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \
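
Cast expressions follow the same rule: both the `try_cast` builder and a hand-built `Expr::Cast` now name a `LogicalType`, as in `test_try_cast` and the `Expr::Cast` hunk above. A sketch (`scan` is any plan with columns `a` and `c`):

    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::Result;
    use datafusion_expr::expr::Cast;
    use datafusion_expr::{col, try_cast, Expr, LogicalPlan, LogicalPlanBuilder};

    fn demo(scan: LogicalPlan) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(scan)
            .project(vec![
                // TRY_CAST(a AS Float64) via the builder helper
                try_cast(col("a"), LogicalType::Float64),
                // CAST(c AS Float64) built by hand
                Expr::Cast(Cast::new(Box::new(col("c")), LogicalType::Float64)),
            ])?
            .build()
    }
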
diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs
index 63b357510f2f..f8be9b965f9a 100644
--- a/datafusion/optimizer/src/propagate_empty_relation.rs
+++ b/datafusion/optimizer/src/propagate_empty_relation.rs
@@ -246,9 +246,10 @@ fn empty_child(plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
 mod tests {
     use std::sync::Arc;

-    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::datatypes::{DataType, Field, Fields, Schema};
     use datafusion_common::{Column, DFSchema, JoinType, ScalarValue};
+    use datafusion_common::logical_type::fields::LogicalFields;
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
         binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, Operator,
@@ -574,7 +575,7 @@ mod tests {
     fn test_empty_with_non_empty() -> Result<()> {
         let table_scan = test_table_scan()?;

-        let fields = test_table_scan_fields();
+        let fields = LogicalFields::from(Fields::from(test_table_scan_fields()));

         let empty = LogicalPlan::EmptyRelation(EmptyRelation {
             produce_one_row: false,
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index fa432ad76de5..91edc6844313 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -1149,7 +1149,7 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use async_trait::async_trait;
-
+    use datafusion_common::logical_type::schema::LogicalSchema;
     use datafusion_common::ScalarValue;
     use datafusion_expr::expr::ScalarFunction;
     use datafusion_expr::logical_plan::table_scan;
@@ -2387,7 +2387,7 @@ mod tests {
             table_name: "test".into(),
             filters: vec![],
             projected_schema: Arc::new(DFSchema::try_from(
-                (*test_provider.schema()).clone(),
+                LogicalSchema::from((*test_provider.schema()).clone()),
             )?),
             projection: None,
             source: Arc::new(test_provider),
@@ -2459,7 +2459,7 @@ mod tests {
             table_name: "test".into(),
             filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
             projected_schema: Arc::new(DFSchema::try_from(
-                (*test_provider.schema()).clone(),
+                LogicalSchema::from((*test_provider.schema()).clone()),
             )?),
             projection: Some(vec![0]),
             source: Arc::new(test_provider),
@@ -2488,7 +2488,7 @@ mod tests {
             table_name: "test".into(),
             filters: vec![],
             projected_schema: Arc::new(DFSchema::try_from(
-                (*test_provider.schema()).clone(),
+                LogicalSchema::from((*test_provider.schema()).clone()),
             )?),
             projection: Some(vec![0]),
             source: Arc::new(test_provider),
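
Where a test previously built a `TableScan`'s `projected_schema` straight from the provider's arrow `Schema`, it now lifts that schema into a `LogicalSchema` first, as the three `projected_schema` hunks above do. A minimal sketch (the single-column schema is illustrative):

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::logical_type::schema::LogicalSchema;
    use datafusion_common::{DFSchema, Result};

    fn demo_projected_schema() -> Result<DFSchema> {
        let arrow_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        // arrow Schema -> LogicalSchema -> DFSchema
        DFSchema::try_from(LogicalSchema::from(arrow_schema))
    }
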
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index 0333cc8dde36..b79343f9b041 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -392,7 +392,7 @@ mod tests {
     use super::*;
     use crate::test::*;

-    use arrow::datatypes::DataType;
+    use datafusion_common::logical_type::LogicalType;
     use datafusion_expr::test::function_stub::sum;
     use datafusion_expr::{col, lit, max, min, out_ref_col, scalar_subquery, Between};

@@ -403,7 +403,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
                 .project(vec![max(col("orders.o_custkey"))])?
@@ -447,7 +447,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
                 .filter(
                     col("lineitem.l_orderkey")
-                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "orders.o_orderkey")),
                 )?
                 .aggregate(
                     Vec::<Expr>::new(),
@@ -461,7 +461,7 @@ mod tests {
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey"))
                         .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
@@ -502,7 +502,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .and(col("o_orderkey").eq(lit(1))),
                 )?
@@ -540,8 +540,8 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
                 .project(vec![max(col("orders.o_custkey"))])?
@@ -610,7 +610,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .not_eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
@@ -637,7 +637,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .lt(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
@@ -664,7 +664,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .or(col("o_orderkey").eq(lit(1))),
                 )?
@@ -713,7 +713,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
@@ -775,7 +775,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
@@ -814,7 +814,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
@@ -854,7 +854,7 @@ mod tests {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
@@ -893,7 +893,7 @@ mod tests {
     fn exists_subquery_correlated() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
-                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
+                .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq.a")))?
                 .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
                 .project(vec![min(col("c"))])?
                 .build()?,
@@ -989,7 +989,7 @@ mod tests {
         let sq1 = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
@@ -999,7 +999,7 @@ mod tests {
         let sq2 = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
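
For reference, the shape these rewritten tests exercise: a correlated scalar subquery whose outer reference is typed as `LogicalType::Int64`, compared against an aggregate of the inner relation. A sketch using hypothetical `outer`/`inner` scans in place of the `scan_tpch_table` fixtures:

    use std::sync::Arc;
    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::Result;
    use datafusion_expr::{col, max, out_ref_col, scalar_subquery, Expr, LogicalPlan, LogicalPlanBuilder};

    fn demo(outer: LogicalPlan, inner: LogicalPlan) -> Result<LogicalPlan> {
        let sq = Arc::new(
            LogicalPlanBuilder::from(inner)
                .filter(out_ref_col(LogicalType::Int64, "outer.a").eq(col("inner.a")))?
                .aggregate(Vec::<Expr>::new(), vec![max(col("inner.a"))])?
                .project(vec![max(col("inner.a"))])?
                .build()?,
        );
        LogicalPlanBuilder::from(outer)
            .filter(col("outer.a").eq(scalar_subquery(sq)))?
            .build()
    }
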
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index f2c80e4a7207..03c7c304b963 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,6 +32,7 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion_common::logical_type::extension::ExtensionType;
 use datafusion_expr::expr::{
     AggregateFunctionDefinition, InList, InSubquery, WindowFunction,
 };
@@ -1050,7 +1051,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
                 Transformed::yes(Expr::Literal(ScalarValue::new_zero(
-                    &info.get_data_type(&left)?,
+                    &info.get_data_type(&left)?.physical_type(),
                 )?))
             }

@@ -1061,7 +1062,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
                 Transformed::yes(Expr::Literal(ScalarValue::new_zero(
-                    &info.get_data_type(&left)?,
+                    &info.get_data_type(&left)?.physical_type(),
                 )?))
             }

@@ -1136,7 +1137,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
                 Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
-                    &info.get_data_type(&left)?,
+                    &info.get_data_type(&left)?.physical_type(),
                 )?))
             }

@@ -1147,7 +1148,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
                 Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
-                    &info.get_data_type(&left)?,
+                    &info.get_data_type(&left)?.physical_type(),
                 )?))
             }

@@ -1222,7 +1223,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
                 Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
-                    &info.get_data_type(&left)?,
+                    &info.get_data_type(&left)?.physical_type(),
                 )?))
             }

@@ -1233,7 +1234,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
                 Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
-                    &info.get_data_type(&left)?,
+                    &info.get_data_type(&left)?.physical_type(),
                 )?))
             }

@@ -1245,7 +1246,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
             }) if expr_contains(&left, &right, BitwiseXor) => {
                 let expr = delete_xor_in_complex_expr(&left, &right, false);
                 Transformed::yes(if expr == *right {
-                    Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?)
+                    Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?.physical_type())?)
                 } else {
                     expr
                 })
@@ -1259,7 +1260,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
             }) if expr_contains(&right, &left, BitwiseXor) => {
                 let expr = delete_xor_in_complex_expr(&right, &left, true);
                 Transformed::yes(if expr == *left {
-                    Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+                    Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?.physical_type())?)
                 } else {
                     expr
                 })
@@ -1783,7 +1784,9 @@ mod tests {
         ops::{BitAnd, BitOr, BitXor},
         sync::Arc,
     };
-
+    use datafusion_common::logical_type::field::LogicalField;
+    use datafusion_common::logical_type::LogicalType;
+    use datafusion_common::logical_type::schema::LogicalSchema;
     use crate::simplify_expressions::SimplifyContext;
     use crate::test::test_table_scan_with_name;

@@ -1822,9 +1825,9 @@ mod tests {
     }

     fn test_schema() -> DFSchemaRef {
-        Schema::new(vec![
-            Field::new("i", DataType::Int64, false),
-            Field::new("b", DataType::Boolean, true),
+        LogicalSchema::new(vec![
+            LogicalField::new("i", LogicalType::Int64, false),
+            LogicalField::new("b", LogicalType::Boolean, true),
         ])
         .to_dfschema_ref()
         .unwrap()
@@ -3007,14 +3010,14 @@ mod tests {
         Arc::new(
             DFSchema::from_unqualifed_fields(
                 vec![
-                    Field::new("c1", DataType::Utf8, true),
-                    Field::new("c2", DataType::Boolean, true),
-                    Field::new("c3", DataType::Int64, true),
-                    Field::new("c4", DataType::UInt32, true),
-                    Field::new("c1_non_null", DataType::Utf8, false),
-                    Field::new("c2_non_null", DataType::Boolean, false),
-                    Field::new("c3_non_null", DataType::Int64, false),
-                    Field::new("c4_non_null", DataType::UInt32, false),
+                    LogicalField::new("c1", LogicalType::Utf8, true),
+                    LogicalField::new("c2", LogicalType::Boolean, true),
+                    LogicalField::new("c3", LogicalType::Int64, true),
+                    LogicalField::new("c4", LogicalType::UInt32, true),
+                    LogicalField::new("c1_non_null", LogicalType::Utf8, false),
+                    LogicalField::new("c2_non_null", LogicalType::Boolean, false),
+                    LogicalField::new("c3_non_null", LogicalType::Int64, false),
+                    LogicalField::new("c4_non_null", LogicalType::UInt32, false),
                 ]
                 .into(),
                 HashMap::new(),
@@ -3102,7 +3105,7 @@ mod tests {
     #[test]
     fn simplify_expr_eq() {
         let schema = expr_test_schema();
-        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
+        assert_eq!(col("c2").get_type(&schema).unwrap(), LogicalType::Boolean);

         // true = true -> true
         assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
@@ -3126,7 +3129,7 @@ mod tests {
         // expression to non-boolean.
         //
         // Make sure c1 column to be used in tests is not boolean type
-        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
+        assert_eq!(col("c1").get_type(&schema).unwrap(), LogicalType::Utf8);

         // don't fold c1 = foo
         assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
@@ -3136,7 +3139,7 @@ mod tests {
     fn simplify_expr_not_eq() {
         let schema = expr_test_schema();

-        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
+        assert_eq!(col("c2").get_type(&schema).unwrap(), LogicalType::Boolean);

         // c2 != true -> !c2
         assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
@@ -3157,7 +3160,7 @@ mod tests {
         // when one of the operand is not of boolean type, folding the
         // other boolean constant will change return type of
         // expression to non-boolean.
-        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
+        assert_eq!(col("c1").get_type(&schema).unwrap(), LogicalType::Utf8);

         assert_eq!(
             simplify(col("c1").not_eq(lit("foo"))),
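
`ScalarValue` constructors still speak arrow's `DataType`, so the simplifier maps each `LogicalType` back through the `ExtensionType::physical_type()` accessor imported at the top of this file. A sketch of that round trip, assuming `LogicalType::Int64` maps to the arrow type of the same name:

    use datafusion_common::logical_type::extension::ExtensionType;
    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::{Result, ScalarValue};

    fn demo_zero() -> Result<ScalarValue> {
        // logical -> physical, then build the zero literal the rewrite emits
        ScalarValue::new_zero(&LogicalType::Int64.physical_type())
    }
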
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index e650d4c09c23..628da80bb3dd 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -93,7 +93,7 @@ impl SimplifyExpressions {
             // projection applied for simplification
             Arc::new(DFSchema::try_from_qualified_schema(
                 scan.table_name.clone(),
-                &scan.source.schema(),
+                &scan.source.schema().as_ref().clone().into(),
             )?)
         } else {
             Arc::new(DFSchema::empty())
@@ -151,7 +151,7 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
     use chrono::{DateTime, Utc};
-
+    use datafusion_common::logical_type::LogicalType;
     use crate::optimizer::Optimizer;
     use datafusion_expr::logical_plan::builder::table_scan_with_filters;
     use datafusion_expr::logical_plan::table_scan;
@@ -445,7 +445,7 @@ mod tests {
     #[test]
     fn cast_expr() -> Result<()> {
         let table_scan = test_table_scan();
-        let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))];
+        let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), LogicalType::Int32))];
         let plan = LogicalPlanBuilder::from(table_scan)
             .project(proj)?
             .build()?;
@@ -703,9 +703,9 @@ mod tests {
         let t1 = test_table_scan_with_name("t1")?;
         let t2 = test_table_scan_with_name("t2")?;

-        let left_key = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, t1.schema())?;
+        let left_key = col("t1.a") + lit(1i64).cast_to(&LogicalType::UInt32, t1.schema())?;
         let right_key =
-            col("t2.a") + lit(2i64).cast_to(&DataType::UInt32, t2.schema())?;
+            col("t2.a") + lit(2i64).cast_to(&LogicalType::UInt32, t2.schema())?;
         let plan = LogicalPlanBuilder::from(t1)
             .join_with_expr_keys(
                 t2,
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 2c7e8644026e..dbabdb5c926f 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -59,7 +59,7 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
         .schema()
         .fields()
         .iter()
-        .map(|f| f.name().clone())
+        .map(|f| f.name().to_string())
         .collect();
     assert_eq!(actual, expected);
 }
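
`Expr::cast_to` keeps its schema argument but now takes a `&LogicalType`, as in the `join_with_expr_keys` hunk just above (and `extract_equijoin_predicate` earlier in the patch). A sketch with hypothetical scans `t1`/`t2`:

    use datafusion_common::logical_type::LogicalType;
    use datafusion_common::Result;
    use datafusion_expr::{col, lit, Expr, ExprSchemable, LogicalPlan};

    fn demo_keys(t1: &LogicalPlan, t2: &LogicalPlan) -> Result<(Expr, Expr)> {
        // Each literal is cast against the schema of the plan it will join on.
        let left = col("t1.a") + lit(1i64).cast_to(&LogicalType::UInt32, t1.schema())?;
        let right = col("t2.a") + lit(2i64).cast_to(&LogicalType::UInt32, t2.schema())?;
        Ok((left, right))
    }
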
DataType::LargeUtf8) -} - -/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a dictionary -fn is_supported_dictionary_type(data_type: &DataType) -> bool { - matches!(data_type, - DataType::Dictionary(_, inner) if is_supported_type(inner)) +fn is_supported_string_type(data_type: &LogicalType) -> bool { + matches!(data_type, LogicalType::Utf8 | LogicalType::LargeUtf8) } /// Convert a literal value from one data type to another fn try_cast_literal_to_type( lit_value: &ScalarValue, - target_type: &DataType, + target_type: &LogicalType, ) -> Option { - let lit_data_type = lit_value.data_type(); + let lit_data_type = lit_value.data_type().into(); if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { return None; } @@ -324,15 +318,14 @@ fn try_cast_literal_to_type( } try_cast_numeric_literal(lit_value, target_type) .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| try_cast_dictionary(lit_value, target_type)) } /// Convert a numeric value from one numeric data type to another fn try_cast_numeric_literal( lit_value: &ScalarValue, - target_type: &DataType, + target_type: &LogicalType, ) -> Option { - let lit_data_type = lit_value.data_type(); + let lit_data_type = lit_value.data_type().into(); if !is_supported_numeric_type(&lit_data_type) || !is_supported_numeric_type(target_type) { @@ -340,29 +333,29 @@ fn try_cast_numeric_literal( } let mul = match target_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => 1_i128, - DataType::Timestamp(_, _) => 1_i128, - DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + LogicalType::UInt8 + | LogicalType::UInt16 + | LogicalType::UInt32 + | LogicalType::UInt64 + | LogicalType::Int8 + | LogicalType::Int16 + | LogicalType::Int32 + | LogicalType::Int64 => 1_i128, + LogicalType::Timestamp(_, _) => 1_i128, + LogicalType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), _ => return None, }; let (target_min, target_max) = match target_type { - DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), - DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), - DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), - DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), - DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), - DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), - DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), - DataType::Decimal128(precision, _) => ( + LogicalType::UInt8 => (u8::MIN as i128, u8::MAX as i128), + LogicalType::UInt16 => (u16::MIN as i128, u16::MAX as i128), + LogicalType::UInt32 => (u32::MIN as i128, u32::MAX as i128), + LogicalType::UInt64 => (u64::MIN as i128, u64::MAX as i128), + LogicalType::Int8 => (i8::MIN as i128, i8::MAX as i128), + LogicalType::Int16 => (i16::MIN as i128, i16::MAX as i128), + LogicalType::Int32 => (i32::MIN as i128, i32::MAX as i128), + LogicalType::Int64 => (i64::MIN as i128, i64::MAX as i128), + LogicalType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), + LogicalType::Decimal128(precision, _) => ( // Different precision for decimal128 can store different range of value. 
@@ -413,47 +406,47 @@ // the value cast from lit to the target type is in the range of target type. // return the target type of scalar value let result_scalar = match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), - DataType::Timestamp(TimeUnit::Second, tz) => { + LogicalType::Int8 => ScalarValue::Int8(Some(value as i8)), + LogicalType::Int16 => ScalarValue::Int16(Some(value as i16)), + LogicalType::Int32 => ScalarValue::Int32(Some(value as i32)), + LogicalType::Int64 => ScalarValue::Int64(Some(value as i64)), + LogicalType::UInt8 => ScalarValue::UInt8(Some(value as u8)), + LogicalType::UInt16 => ScalarValue::UInt16(Some(value as u16)), + LogicalType::UInt32 => ScalarValue::UInt32(Some(value as u32)), + LogicalType::UInt64 => ScalarValue::UInt64(Some(value as u64)), + LogicalType::Timestamp(TimeUnit::Second, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Second, tz.clone()), + LogicalType::Timestamp(TimeUnit::Second, tz.clone()), value, ); ScalarValue::TimestampSecond(value, tz.clone()) } - DataType::Timestamp(TimeUnit::Millisecond, tz) => { + LogicalType::Timestamp(TimeUnit::Millisecond, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + LogicalType::Timestamp(TimeUnit::Millisecond, tz.clone()), value, ); ScalarValue::TimestampMillisecond(value, tz.clone()) } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { + LogicalType::Timestamp(TimeUnit::Microsecond, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + LogicalType::Timestamp(TimeUnit::Microsecond, tz.clone()), value, ); ScalarValue::TimestampMicrosecond(value, tz.clone()) } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + LogicalType::Timestamp(TimeUnit::Nanosecond, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + LogicalType::Timestamp(TimeUnit::Nanosecond, tz.clone()), value, ); ScalarValue::TimestampNanosecond(value, tz.clone()) } - DataType::Decimal128(p, s) => { + LogicalType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) } _ => { @@ -470,62 +463,36 @@ fn try_cast_string_literal( lit_value: &ScalarValue, - target_type: &DataType, + target_type: &LogicalType, ) -> Option<ScalarValue> { let string_value = match lit_value { ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => s.clone(), _ => return None, }; let scalar_value = match target_type { - DataType::Utf8 => ScalarValue::Utf8(string_value), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), + LogicalType::Utf8 => ScalarValue::Utf8(string_value), + LogicalType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), _ => return None, }; Some(scalar_value) }
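The dictionary helper that follows is deleted outright rather than ported: a dictionary array is a physical encoding, and at the logical level a Dictionary(Int32, Utf8) column is simply Utf8, so there is no longer a wrap/unwrap cast to perform. A sketch of the assumption this relies on (hypothetical, using the From<DataType> conversion this PR introduces):

    use arrow_schema::DataType;
    use datafusion_common::logical_type::LogicalType;

    let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    // assumed: the logical view erases the dictionary encoding
    let logical: LogicalType = dict.into();
    assert_eq!(logical, LogicalType::Utf8);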
-/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary -fn try_cast_dictionary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option<ScalarValue> { - let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { - // Unwrap dictionary when inner type matches target type - (ScalarValue::Dictionary(_, inner_value), _) - if inner_value.data_type() == *target_type => - { - (**inner_value).clone() - } - // Wrap type when target type is dictionary - (_, DataType::Dictionary(index_type, inner_type)) - if **inner_type == lit_value_type => - { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) - } - _ => { - return None; - } - }; - Some(result_scalar) -} - /// Cast a timestamp value from one unit to another -fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option<i64> { +fn cast_between_timestamp(from: LogicalType, to: LogicalType, value: i128) -> Option<i64> { let value = value as i64; let from_scale = match from { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + LogicalType::Timestamp(TimeUnit::Second, _) => 1, + LogicalType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + LogicalType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + LogicalType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, _ => return Some(value), }; let to_scale = match to { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + LogicalType::Timestamp(TimeUnit::Second, _) => 1, + LogicalType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + LogicalType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + LogicalType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, _ => return Some(value), }; @@ -543,7 +510,8 @@ mod tests { use super::*; use arrow::compute::{cast_with_options, CastOptions}; - use arrow::datatypes::Field; + use datafusion_common::logical_type::extension::ExtensionType; + use datafusion_common::logical_type::field::LogicalField; use datafusion_common::tree_node::TransformedResult; use datafusion_expr::{cast, col, in_list, try_cast}; @@ -551,7 +519,7 @@ fn test_not_unwrap_cast_comparison() { let schema = expr_test_schema(); // cast(INT32(c1), INT64) > INT64(c2) - let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); + let c1_gt_c2 = cast(col("c1"), LogicalType::Int64).gt(col("c2")); assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); // INT32(c1) < INT32(16), the type is same let expr_lt = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); + let expr_lt = cast(col("c1"), LogicalType::Int64).lt(lit(99999999999i64)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); } @@ -568,25 +536,25 @@ let schema = expr_test_schema(); // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expr_lt = cast(col("c1"), LogicalType::Int64).lt(lit(16i64)); let expected = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt,
&schema), expected); - let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expr_lt = try_cast(col("c1"), LogicalType::Int64).lt(lit(16i64)); let expected = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32)); + let c2_eq_lit = cast(col("c2"), LogicalType::Int32).eq(lit(16i32)); let expected = col("c2").eq(lit(16i64)); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) - let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); + let c1_lt_lit_null = cast(col("c1"), LogicalType::Int64).lt(null_i64()); let expected = col("c1").lt(null_i32()); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) - let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); + let lit_lt_lit = cast(null_i8(), LogicalType::Int32).lt(lit(12i32)); let expected = null_i8().lt(lit(12i8)); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } @@ -595,77 +563,38 @@ mod tests { fn test_unwrap_cast_comparison_unsigned() { // "cast(c6, UINT64) = 0u64 => c6 = 0u32 let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); + let expr_input = cast(col("c6"), LogicalType::UInt64).eq(lit(0u64)); let expected = col("c6").eq(lit(0u32)); assert_eq!(optimize_test(expr_input, &schema), expected); } - #[test] - fn test_unwrap_cast_comparison_string() { - let schema = expr_test_schema(); - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("value")), - ); - - // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') - let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone())); - let expected = col("str1").eq(lit("value")); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') - let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); - let expected = col("tag").eq(lit(dict.clone())); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // Verify reversed argument order - // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 - let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); - let expected = lit("value").eq(col("str1")); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_unwrap_cast_comparison_large_string() { - let schema = expr_test_schema(); - // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), - ); - let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict.clone())); - let expected = - col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - #[test] fn test_not_unwrap_cast_with_decimal_comparison() { let schema = expr_test_schema(); // integer to decimal: value is out of the bounds of the decimal // cast(c3, INT64) = INT64(100000000000000000) - let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64)); + let expr_eq = cast(col("c3"), 
LogicalType::Int64).eq(lit(100000000000000000i64)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // cast(c4, INT64) = INT64(1000) will overflow the i128 - let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64)); + let expr_eq = cast(col("c4"), LogicalType::Int64).eq(lit(1000i64)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // decimal to decimal: value will lose the scale when convert to the target data type // c3 = DECIMAL(12340,20,4) let expr_eq = - cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4)); + cast(col("c3"), LogicalType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // decimal to integer // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type let expr_eq = - cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1)); + cast(col("c1"), LogicalType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type let expr_eq = - cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2)); + cast(col("c1"), LogicalType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); } @@ -674,32 +603,32 @@ mod tests { let schema = expr_test_schema(); // integer to decimal // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64)); + let expr_lt = try_cast(col("c3"), LogicalType::Int64).lt(lit(16i64)); let expected = col("c3").lt(lit_decimal(1600, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < INT64(NULL) - let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); + let c1_lt_lit_null = cast(col("c3"), LogicalType::Int64).lt(null_i64()); let expected = col("c3").lt(null_decimal(18, 2)); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); + cast(col("c3"), LogicalType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); let expected = col("c3").lt(lit_decimal(12300, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); + cast(col("c3"), LogicalType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); let expected = col("c3").lt(lit_decimal(123, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal to integer // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) let expr_lt = - cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); + cast(col("c1"), LogicalType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); let expected = col("c1").lt(lit(123i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -710,21 +639,21 @@ mod tests { // internal left type is not supported // FLOAT32(C5) in ... 
let expr_lt = - cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); + cast(col("c5"), LogicalType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) - let expr_lt = cast(col("c1"), DataType::Float32) + let expr_lt = cast(col("c1"), LogicalType::Float32) .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = cast(col("c1"), DataType::Int64) + let expr_lt = cast(col("c1"), LogicalType::Int64) .in_list(vec![lit(12i32), lit(99999999999i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) - let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( + let expr_lt = cast(col("c3"), LogicalType::Decimal128(12, 3)).in_list( vec![ lit_decimal(12, 12, 3), lit_decimal(12, 12, 3), @@ -740,19 +669,19 @@ mod tests { let schema = expr_test_schema(); // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) let expr_lt = - cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false); + cast(col("c1"), LogicalType::Int64).in_list(vec![lit(12i64), lit(24i64)], false); let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) let expr_lt = - cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false); + cast(col("c2"), LogicalType::Int32).in_list(vec![null_i32(), lit(14i32)], false); let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal test case // c3 is decimal(18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( + let expr_lt = cast(col("c3"), LogicalType::Decimal128(19, 3)).in_list( vec![ lit_decimal(12000, 19, 3), lit_decimal(24000, 19, 3), @@ -773,7 +702,7 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); // cast(INT32(12), INT64) IN (.....) 
- let expr_lt = cast(lit(12i32), DataType::Int64) + let expr_lt = cast(lit(12i32), LogicalType::Int64) .in_list(vec![lit(13i64), lit(12i64)], false); let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); @@ -784,7 +713,7 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x"); + let expr_lt = cast(col("c1"), LogicalType::Int64).lt(lit(16i64)).alias("x"); let expected = col("c1").lt(lit(16i32)).alias("x"); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -794,9 +723,9 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast( + let expr_lt = cast(col("c1"), LogicalType::Int64).lt(lit(16i64)).or(cast( col("c1"), - DataType::Int64, + LogicalType::Int64, ) .gt(lit(32i64))); let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32))); @@ -809,12 +738,12 @@ mod tests { // but the type of c6 is uint32 // the rewriter will not throw error and just return the original expr let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64)); + let expr_input = cast(col("c6"), LogicalType::Float64).eq(lit(0f64)); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); // inlist for unsupported data type let expr_input = - in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false); + in_list(cast(col("c6"), LogicalType::Float64), vec![lit(0f64)], false); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } @@ -841,17 +770,16 @@ mod tests { Arc::new( DFSchema::from_unqualifed_fields( vec![ - Field::new("c1", DataType::Int32, false), - Field::new("c2", DataType::Int64, false), - Field::new("c3", DataType::Decimal128(18, 2), false), - Field::new("c4", DataType::Decimal128(38, 37), false), - Field::new("c5", DataType::Float32, false), - Field::new("c6", DataType::UInt32, false), - Field::new("ts_nano_none", timestamp_nano_none_type(), false), - Field::new("ts_nano_utf", timestamp_nano_utc_type(), false), - Field::new("str1", DataType::Utf8, false), - Field::new("largestr", DataType::LargeUtf8, false), - Field::new("tag", dictionary_tag_type(), false), + LogicalField::new("c1", LogicalType::Int32, false), + LogicalField::new("c2", LogicalType::Int64, false), + LogicalField::new("c3", LogicalType::Decimal128(18, 2), false), + LogicalField::new("c4", LogicalType::Decimal128(38, 37), false), + LogicalField::new("c5", LogicalType::Float32, false), + LogicalField::new("c6", LogicalType::UInt32, false), + LogicalField::new("ts_nano_none", timestamp_nano_none_type(), false), + LogicalField::new("ts_nano_utf", timestamp_nano_utc_type(), false), + LogicalField::new("str1", LogicalType::Utf8, false), + LogicalField::new("largestr", LogicalType::LargeUtf8, false), ] .into(), HashMap::new(), @@ -889,19 +817,14 @@ mod tests { lit(ScalarValue::Decimal128(None, precision, scale)) } - fn timestamp_nano_none_type() -> DataType { - DataType::Timestamp(TimeUnit::Nanosecond, None) + fn timestamp_nano_none_type() -> LogicalType { + LogicalType::Timestamp(TimeUnit::Nanosecond, None) } // this is the type that now() returns - fn 
timestamp_nano_utc_type() -> DataType { + fn timestamp_nano_utc_type() -> LogicalType { let utc = Some("+0:00".into()); - DataType::Timestamp(TimeUnit::Nanosecond, utc) - } - - // a dictonary type for storing string tags - fn dictionary_tag_type() -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + LogicalType::Timestamp(TimeUnit::Nanosecond, utc) } #[test] @@ -926,7 +849,7 @@ mod tests { for s2 in &scalars { let expected_value = ExpectedCast::Value(s2.clone()); - expect_cast(s1.clone(), s2.data_type(), expected_value); + expect_cast(s1.clone(), s2.data_type().into(), expected_value); } } } @@ -951,28 +874,28 @@ mod tests { for s2 in &scalars { let expected_value = ExpectedCast::Value(s2.clone()); - expect_cast(s1.clone(), s2.data_type(), expected_value); + expect_cast(s1.clone(), s2.data_type().into(), expected_value); } } let max_i32 = ScalarValue::Int32(Some(i32::MAX)); expect_cast( max_i32, - DataType::UInt64, + LogicalType::UInt64, ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), ); let min_i32 = ScalarValue::Int32(Some(i32::MIN)); expect_cast( min_i32, - DataType::Int64, + LogicalType::Int64, ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), ); let max_i64 = ScalarValue::Int64(Some(i64::MAX)); expect_cast( max_i64, - DataType::UInt64, + LogicalType::UInt64, ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), ); } @@ -984,28 +907,28 @@ mod tests { let max_i64 = ScalarValue::Int64(Some(i64::MAX)); let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); - expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); + expect_cast(max_i64.clone(), LogicalType::Int8, ExpectedCast::NoValue); - expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); + expect_cast(max_i64.clone(), LogicalType::Int16, ExpectedCast::NoValue); - expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); + expect_cast(max_i64, LogicalType::Int32, ExpectedCast::NoValue); - expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); + expect_cast(max_u64, LogicalType::Int64, ExpectedCast::NoValue); - expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); + expect_cast(min_i64, LogicalType::UInt64, ExpectedCast::NoValue); - expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); + expect_cast(min_i32, LogicalType::UInt64, ExpectedCast::NoValue); // decimal out of range expect_cast( ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), - DataType::Int64, + LogicalType::Int64, ExpectedCast::NoValue, ); expect_cast( ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), - DataType::Int64, + LogicalType::Int64, ExpectedCast::NoValue, ); } @@ -1014,19 +937,19 @@ mod tests { fn test_try_decimal_cast_in_range() { expect_cast( ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(3, 0), + LogicalType::Decimal128(3, 0), ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), ); expect_cast( ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(8, 0), + LogicalType::Decimal128(8, 0), ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), ); expect_cast( ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(8, 5), + LogicalType::Decimal128(8, 5), ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), ); } @@ -1036,14 +959,14 @@ mod tests { // decimal would lose precision expect_cast( ScalarValue::Decimal128(Some(12345), 5, 2), - DataType::Decimal128(3, 0), + LogicalType::Decimal128(3, 0), 
ExpectedCast::NoValue, ); // decimal would lose precision expect_cast( ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(2, 0), + LogicalType::Decimal128(2, 0), ExpectedCast::NoValue, ); } @@ -1084,11 +1007,11 @@ mod tests { // so double check it here assert_eq!(lit_tz_none, lit_tz_utc); - // e.g. DataType::Timestamp(_, None) - let dt_tz_none = lit_tz_none.data_type(); + // e.g. LogicalType::Timestamp(_, None) + let dt_tz_none: LogicalType = lit_tz_none.data_type().into(); - // e.g. DataType::Timestamp(_, Some(utc)) - let dt_tz_utc = lit_tz_utc.data_type(); + // e.g. LogicalType::Timestamp(_, Some(utc)) + let dt_tz_utc: LogicalType = lit_tz_utc.data_type().into(); // None <--> None expect_cast( @@ -1121,7 +1044,7 @@ mod tests { // timestamp to int64 expect_cast( lit_tz_utc.clone(), - DataType::Int64, + LogicalType::Int64, ExpectedCast::Value(ScalarValue::Int64(Some(12345))), ); @@ -1142,7 +1065,7 @@ mod tests { // timestamp to string (not supported yet) expect_cast( lit_tz_utc.clone(), - DataType::LargeUtf8, + LogicalType::LargeUtf8, ExpectedCast::NoValue, ); } @@ -1153,7 +1076,7 @@ mod tests { // int64 to list expect_cast( ScalarValue::Int64(Some(12345)), - DataType::List(Arc::new(Field::new("f", DataType::Int32, true))), + LogicalType::new_list(LogicalType::Int32, true), ExpectedCast::NoValue, ); } @@ -1171,7 +1094,7 @@ mod tests { /// casting is consistent with the Arrow kernels fn expect_cast( literal: ScalarValue, - target_type: DataType, + target_type: LogicalType, expected_result: ExpectedCast, ) { let actual_value = try_cast_literal_to_type(&literal, &target_type); @@ -1199,7 +1122,7 @@ mod tests { .expect("Failed to convert to array of size"); let cast_array = cast_with_options( &literal_array, - &target_type, + &target_type.physical_type(), &CastOptions::default(), ) .expect("Expected to be cast array with arrow cast kernel"); @@ -1212,9 +1135,9 @@ mod tests { // Verify that for timestamp types the timezones are the same // (ScalarValue::cmp doesn't account for timezones); if let ( - DataType::Timestamp(left_unit, left_tz), - DataType::Timestamp(right_unit, right_tz), - ) = (actual_value.data_type(), expected_value.data_type()) + LogicalType::Timestamp(left_unit, left_tz), + LogicalType::Timestamp(right_unit, right_tz), + ) = (actual_value.data_type().into(), expected_value.data_type().into()) { assert_eq!(left_unit, right_unit); assert_eq!(left_tz, right_tz); @@ -1234,7 +1157,7 @@ mod tests { // same timestamp let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), + &LogicalType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); @@ -1246,7 +1169,7 @@ mod tests { // TimestampNanosecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), + &LogicalType::Timestamp(TimeUnit::Microsecond, None), ) .unwrap(); @@ -1258,7 +1181,7 @@ mod tests { // TimestampNanosecond to TimestampMillisecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), + &LogicalType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); @@ -1267,7 +1190,7 @@ mod tests { // TimestampNanosecond to TimestampSecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Second, None), + 
&LogicalType::Timestamp(TimeUnit::Second, None), ) .unwrap(); @@ -1276,7 +1199,7 @@ mod tests { // TimestampMicrosecond to TimestampNanosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampMicrosecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), + &LogicalType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); @@ -1288,7 +1211,7 @@ mod tests { // TimestampMicrosecond to TimestampMillisecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampMicrosecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), + &LogicalType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); @@ -1297,7 +1220,7 @@ mod tests { // TimestampMicrosecond to TimestampSecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampMicrosecond(Some(123456789), None), - &DataType::Timestamp(TimeUnit::Second, None), + &LogicalType::Timestamp(TimeUnit::Second, None), ) .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); @@ -1305,7 +1228,7 @@ mod tests { // TimestampMillisecond to TimestampNanosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampMillisecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), + &LogicalType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); assert_eq!( @@ -1316,7 +1239,7 @@ mod tests { // TimestampMillisecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampMillisecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), + &LogicalType::Timestamp(TimeUnit::Microsecond, None), ) .unwrap(); assert_eq!( @@ -1326,7 +1249,7 @@ mod tests { // TimestampMillisecond to TimestampSecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampMillisecond(Some(123456789), None), - &DataType::Timestamp(TimeUnit::Second, None), + &LogicalType::Timestamp(TimeUnit::Second, None), ) .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); @@ -1334,7 +1257,7 @@ mod tests { // TimestampSecond to TimestampNanosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), + &LogicalType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); assert_eq!( @@ -1345,7 +1268,7 @@ mod tests { // TimestampSecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), + &LogicalType::Timestamp(TimeUnit::Microsecond, None), ) .unwrap(); assert_eq!( @@ -1356,7 +1279,7 @@ mod tests { // TimestampSecond to TimestampMillisecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), + &LogicalType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); assert_eq!( @@ -1367,7 +1290,7 @@ mod tests { // overflow let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampSecond(Some(i64::MAX), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), + &LogicalType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); @@ -1384,33 +1307,8 @@ mod tests { for s2 in &scalars { let expected_value = ExpectedCast::Value(s2.clone()); - expect_cast(s1.clone(), s2.data_type(), expected_value); + expect_cast(s1.clone(), s2.data_type().into(), expected_value); } } } - #[test] - fn test_try_cast_to_dictionary_type() { - fn 
dictionary_type(t: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) - } - fn dictionary_value(value: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) - } - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - for s in &scalars { - expect_cast( - s.clone(), - dictionary_type(s.data_type()), - ExpectedCast::Value(dictionary_value(s.clone())), - ); - expect_cast( - dictionary_value(s.clone()), - s.data_type(), - ExpectedCast::Value(s.clone()), - ) - } - } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c501d5aaa4bf..33678a98da08 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; +use datafusion_common::logical_type::LogicalType; use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_aggregate::average::avg_udaf; @@ -399,7 +400,7 @@ impl ContextProvider for MyContextProvider { self.udafs.get(name).cloned() } - fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { + fn get_variable_type(&self, _variable_names: &[String]) -> Option<LogicalType> { None } diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index d5cd3c6f4af0..39305401288d 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -22,6 +22,7 @@ use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::Schema; use datafusion_common::{exec_err, Result}; +use datafusion_common::logical_type::extension::ExtensionType; use datafusion_expr::expr::Alias; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::Expr; @@ -127,7 +128,7 @@ pub fn limited_convert_logical_expr_to_physical_expr( cast_expr.expr.as_ref(), schema, )?, - cast_expr.data_type.clone(), + cast_expr.data_type.physical_type(), None, ))), Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 53c790ff6b54..98d501c97854 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -429,7 +429,7 @@ pub fn in_list( let expr_data_type = expr.data_type(schema)?; for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; - if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) { + if !DFSchema::datatype_is_logically_equal(&expr_data_type.clone().into(), &list_expr_data_type.clone().into()) { return internal_err!( "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" );
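The reworked in_list guard above compares logical types, which is what lets physically different but logically identical encodings pass the check; a hedged sketch of the intent:

    use arrow_schema::DataType;
    use datafusion_common::DFSchema;

    let plain = DataType::Utf8;
    let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    // assumed to hold once both sides are lifted to LogicalType:
    assert!(DFSchema::datatype_is_logically_equal(&plain.into(), &dict.into()));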
diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index aed2675e0447..f3ed3526b45a 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -160,7 +160,7 @@ pub fn negative( arg: Arc<dyn PhysicalExpr>, input_schema: &Schema, ) -> Result<Arc<dyn PhysicalExpr>> { - let data_type = arg.data_type(input_schema)?; + let data_type = arg.data_type(input_schema)?.into(); if is_null(&data_type) { Ok(arg) } else if !is_signed_numeric(&data_type) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8fe99cdca591..ae0f6e455532 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -24,9 +24,9 @@ use crate::{ }; use arrow::datatypes::Schema; -use datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, -}; +use datafusion_common::{exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema}; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::schema::LogicalSchema; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; @@ -259,12 +259,12 @@ pub fn create_physical_expr( Expr::Cast(Cast { expr, data_type }) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + data_type.clone().physical_type(), ), Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + data_type.clone().physical_type(), ), Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) } @@ -359,7 +359,7 @@ where /// Convert a logical expression to a physical expression (without any simplification, etc) pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> { - let df_schema = schema.clone().to_dfschema().unwrap(); + let df_schema = LogicalSchema::from(schema.clone()).to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); create_physical_expr(expr, &df_schema, &execution_props).unwrap() } @@ -378,7 +378,7 @@ mod tests { let expr = col("letter").eq(lit("A")); let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); - let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; + let df_schema = DFSchema::try_from_qualified_schema("data", &schema.clone().into())?; let p = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?; let batch = RecordBatch::try_new( diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 070034116fb4..67d20837afdc 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -420,11 +420,10 @@ impl<'a> ColOpLit<'a> { #[cfg(test)] mod test { use std::sync::OnceLock; - + use arrow_schema::{DataType, Field, Schema, SchemaRef}; use super::*; use crate::planner::logical2physical; - use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::expr_fn::*; use datafusion_expr::{lit, Expr}; @@ -835,7 +834,7 @@ fn test_analyze(expr: Expr, expected: Vec<LiteralGuarantee>) { println!("Begin analyze of {expr}"); let schema = schema(); - let physical_expr = logical2physical(&expr, &schema); + let physical_expr = logical2physical(&expr, &schema.as_ref().clone().into()); let actual = LiteralGuarantee::analyze(&physical_expr); assert_eq!(
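logical2physical (changed above) is the bridge these tests lean on: it lifts an Arrow Schema into the new LogicalSchema, builds a DFSchema, and plans the expression against it. Usage stays a one-liner; a sketch:

    use arrow_schema::{DataType, Field, Schema};
    use datafusion_expr::{col, lit};

    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let expr = col("a").eq(lit(5i32));
    // Schema -> LogicalSchema -> DFSchema conversion happens inside:
    let physical_expr = logical2physical(&expr, &schema);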
diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index de9fede9ee86..6f54ac19626f 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -42,6 +42,8 @@ use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, JoinSide, ScalarValue, Statistics, TableReference, }; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; #[derive(Debug)] pub enum Error { @@ -158,10 +160,10 @@ impl TryFrom<&protobuf::DfSchema> for DFSchema { df_schema: &protobuf::DfSchema, ) -> datafusion_common::Result<Self, Self::Error> { let df_fields = df_schema.columns.clone(); - let qualifiers_and_fields: Vec<(Option<TableReference>, Arc<Field>)> = df_fields + let qualifiers_and_fields: Vec<(Option<TableReference>, Arc<LogicalField>)> = df_fields .iter() .map(|df_field| { - let field: Field = df_field.field.as_ref().required("field")?; + let field: LogicalField = df_field.field.as_ref().required("field")?; Ok(( df_field .qualifier @@ -190,6 +192,16 @@ impl TryFrom<protobuf::DfSchema> for DFSchemaRef { } } +impl TryFrom<&protobuf::ArrowType> for LogicalType { + type Error = Error; + + fn try_from( + arrow_type: &protobuf::ArrowType, + ) -> datafusion_common::Result<Self, Self::Error> { + DataType::try_from(arrow_type).map(|t| t.into()) + } +} + impl TryFrom<&protobuf::ArrowType> for DataType { type Error = Error; @@ -332,6 +344,14 @@ impl TryFrom<&protobuf::Field> for Field { } } + +impl TryFrom<&protobuf::Field> for LogicalField { + type Error = Error; + fn try_from(field: &protobuf::Field) -> Result<Self, Self::Error> { + Field::try_from(field).map(|t| t.into()) + } +} + impl TryFrom<&protobuf::Schema> for Schema { type Error = Error; diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 877043f66809..d343554e6fe6 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -23,10 +23,7 @@ use crate::protobuf_common::{ }; use arrow::array::{ArrayRef, RecordBatch}; use arrow::csv::WriterBuilder; -use arrow::datatypes::{ - DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, - SchemaRef, TimeUnit, UnionMode, -}; +use arrow::datatypes::{DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode}; use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; use datafusion_common::{ config::{ @@ -39,6 +36,8 @@ use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, JoinSide, ScalarValue, Statistics, }; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::LogicalType; #[derive(Debug)] pub enum Error { @@ -112,6 +111,17 @@ impl TryFrom<&DataType> for protobuf::ArrowType { } } +impl TryFrom<&LogicalType> for protobuf::ArrowType { + type Error = Error; + + fn try_from(val: &LogicalType) -> Result<Self, Self::Error> { + let arrow_type_enum: ArrowTypeEnum = (&val.physical_type()).try_into()?; + Ok(Self { + arrow_type_enum: Some(arrow_type_enum), + }) + } +} + impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { type Error = Error; @@ -262,8 +272,9 @@ impl TryFrom<&DFSchema> for protobuf::DfSchema { let columns = s .iter() .map(|(qualifier, field)| { + let field: Field = field.as_ref().clone().into(); Ok(protobuf::DfField { - field: Some(field.as_ref().try_into()?), + field: Some((&field).try_into()?), qualifier: qualifier.map(|r| protobuf::ColumnRelation { relation: r.to_string(), }),
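Note the wire-format decision in the two conversions above: a LogicalType is serialized by lowering it to its physical DataType first, and deserialized back through the From<DataType> conversion. A sketch of the resulting round trip (hypothetical test; error handling elided with unwrap):

    use datafusion_common::logical_type::LogicalType;

    let t = LogicalType::Utf8;
    let proto = protobuf::ArrowType::try_from(&t).unwrap(); // encodes t.physical_type()
    let back: LogicalType = (&proto).try_into().unwrap();   // decodes through DataType
    assert_eq!(t, back);

Anything that exists only at the logical level is therefore normalized to its physical form on the wire.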
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index cdb9d5260a0f..06e15d19b8f8 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -65,7 +65,9 @@ use datafusion_expr::{ use prost::bytes::BufMut; use prost::Message; - +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::LogicalType; +use datafusion_proto_common::ArrowType; use self::to_proto::serialize_expr; pub mod file_formats; @@ -830,10 +832,10 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Prepare(prepare) => { let input: LogicalPlan = into_logical_plan!(prepare.input, ctx, extension_codec)?; - let data_types: Vec<DataType> = prepare + let data_types: Vec<LogicalType> = prepare .data_types .iter() - .map(DataType::try_from) + .map(|t| DataType::try_from(t).map(|t| t.into())) .collect::<Result<_, _>>()?; LogicalPlanBuilder::from(input) .prepare(prepare.name.clone(), data_types)? @@ -1554,8 +1556,8 @@ impl AsLogicalPlan for LogicalPlanNode { name: name.clone(), data_types: data_types .iter() - .map(|t| t.try_into()) - .collect::<Result<Vec<_>, _>>()?, + .map(|t| (&t.physical_type()).try_into()) + .collect::<Result<Vec<_>, _>>()?, input: Some(Box::new(input)), }, ))), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d54078b72bb7..03a996c12dcd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,8 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::LogicalType; use datafusion::datasource::file_format::arrow::ArrowFormatFactory; use datafusion::datasource::file_format::csv::CsvFormatFactory; use datafusion::datasource::file_format::format_as_file_type; @@ -583,7 +585,7 @@ async fn roundtrip_expr_api() -> Result<()> { // list of expressions to round trip let expr_list = vec![ - encode(col("a").cast_to(&DataType::Utf8, &schema)?, lit("hex")), + encode(col("a").cast_to(&LogicalType::Utf8, &schema)?, lit("hex")), decode(lit("1234"), lit("hex")), array_to_string(make_array(vec![lit(1), lit(2), lit(3)]), lit(",")), array_dims(make_array(vec![lit(1), lit(2), lit(3)])), @@ -691,7 +693,7 @@ async fn roundtrip_expr_api() -> Result<()> { bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2)), - string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), + string_agg(col("a").cast_to(&LogicalType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), ]; @@ -1548,11 +1550,11 @@ fn roundtrip_schema() { fn roundtrip_dfschema() { let dfschema = DFSchema::new_with_metadata( vec![ - (None, Arc::new(Field::new("a", DataType::Int64, false))), + (None, Arc::new(LogicalField::new("a", LogicalType::Int64, false))), ( Some("t".into()), Arc::new( - Field::new("b", DataType::Decimal128(15, 2), true).with_metadata( + LogicalField::new("b", LogicalType::Decimal128(15, 2), true).with_metadata( HashMap::from([(String::from("k1"), String::from("v1"))]), ), ), @@ -1681,7 +1683,7 @@ fn roundtrip_null_literal() { #[test] fn roundtrip_cast() { - let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), LogicalType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx);
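TryCast gets the same treatment as Cast in the next hunk; the second case there is the interesting one, because try_cast only differs at evaluation time, where a failed conversion yields NULL instead of an error, so even an impossible cast must survive the proto round trip as a plan node. Sketch:

    use datafusion_common::logical_type::LogicalType;
    use datafusion_expr::{lit, try_cast};

    // plans (and round-trips) fine; evaluates to NULL rather than failing:
    let expr = try_cast(lit("not a bool"), LogicalType::Boolean);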
@@ -1690,13 +1692,13 @@ #[test] fn roundtrip_try_cast() { let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), LogicalType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); + Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), LogicalType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index aee4cf5a38ed..159bf46f97c6 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -30,6 +30,7 @@ use datafusion_sql::{ TableReference, }; use std::{collections::HashMap, sync::Arc}; +use datafusion_common::logical_type::LogicalType; fn main() { let sql = "SELECT \ @@ -132,7 +133,7 @@ impl ContextProvider for MyContextProvider { self.udafs.get(name).cloned() } - fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { + fn get_variable_type(&self, _variable_names: &[String]) -> Option<LogicalType> { None } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ea460cb3efc2..a1e4719d7c70 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -16,7 +16,6 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow_schema::DataType; use datafusion_common::{ internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; @@ -36,6 +35,7 @@ use sqlparser::ast::{ }; use std::str::FromStr; use strum::IntoEnumIterator; +use datafusion_common::logical_type::LogicalType; /// Suggest a valid function based on an invalid input function name pub fn suggest_valid_function( @@ -474,11 +474,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn check_unnest_arg(arg: &Expr, schema: &DFSchema) -> Result<()> { // Check argument type, array types are supported match arg.get_type(schema)? { - DataType::List(_) - | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) - | DataType::Struct(_) => Ok(()), - DataType::Null => { + LogicalType::List(_) - | LogicalType::LargeList(_) + | LogicalType::LargeList(_) + | LogicalType::FixedSizeList(_, _) + | LogicalType::Struct(_) => Ok(()), + LogicalType::Null => { not_impl_err!("unnest() does not support null yet") } _ => { diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index d297b2e4df5b..87419d97351e 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -16,13 +16,13 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow_schema::Field; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; +use datafusion_common::logical_type::field::LogicalField; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_identifier_to_expr( @@ -280,7 +280,7 @@ fn search_dfschema<'ids, 'schema>( ids: &'ids [String], schema: &'schema DFSchema, ) -> Option<( - &'schema Field, + &'schema LogicalField, Option<&'schema TableReference>, &'ids [String], )> {
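Several ContextProvider implementations change in the same mechanical way in this patch: variable types are now reported logically. A minimal stub of the new method (sketch, not from the patch; belongs inside an impl of the trait):

    use datafusion_common::logical_type::LogicalType;

    fn get_variable_type(&self, variable_names: &[String]) -> Option<LogicalType> {
        // e.g. treat every defined session variable as a string
        (!variable_names.is_empty()).then(|| LogicalType::Utf8)
    }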
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a8af37ee6a37..04ecdce3f164 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::DataType; use arrow_schema::TimeUnit; use datafusion_common::utils::list_ndims; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; @@ -24,6 +23,8 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, }; +use datafusion_common::logical_type::extension::ExtensionType; +use datafusion_common::logical_type::LogicalType; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ @@ -113,8 +114,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if op == Operator::StringConcat { let left_type = left.get_type(schema)?; let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type); - let right_list_ndims = list_ndims(&right_type); + let left_list_ndims = list_ndims(&left_type.physical_type()); + let right_list_ndims = list_ndims(&right_type.physical_type()); // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. @@ -351,12 +352,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // numeric constants are treated as seconds (rather as nanoseconds) // to align with postgres / duckdb semantics let expr = match &dt { - DataType::Timestamp(TimeUnit::Nanosecond, tz) - if expr.get_type(schema)? == DataType::Int64 => + LogicalType::Timestamp(TimeUnit::Nanosecond, tz) + if expr.get_type(schema)? == LogicalType::Int64 => { Expr::Cast(Cast::new( Box::new(expr), - DataType::Timestamp(TimeUnit::Second, tz.clone()), + LogicalType::Timestamp(TimeUnit::Second, tz.clone()), )) } _ => expr, @@ -635,7 +636,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?), match *time_zone { SQLExpr::Value(Value::SingleQuotedString(s)) => { - DataType::Timestamp(TimeUnit::Nanosecond, Some(s.into())) + LogicalType::Timestamp(TimeUnit::Nanosecond, Some(s.into())) } _ => { return not_impl_err!( @@ -813,7 +814,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result<Expr> { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; - if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { + if pattern_type != LogicalType::Utf8 && pattern_type != LogicalType::Null { return plan_err!("Invalid pattern in LIKE expression"); } let escape_char = if let Some(char) = escape_char { @@ -844,7 +845,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result<Expr> { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; - if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { + if pattern_type != LogicalType::Utf8 && pattern_type != LogicalType::Null { return plan_err!("Invalid pattern in SIMILAR TO expression"); } let escape_char = if let Some(char) = escape_char { @@ -1023,6 +1024,7 @@ mod tests { use std::sync::Arc; use arrow::datatypes::{Field, Schema}; + use arrow_schema::DataType; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -1074,7 +1076,7 @@ - fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { + fn get_variable_type(&self, _variable_names: &[String]) -> Option<LogicalType> { None }
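The timestamp hunk above preserves the postgres/duckdb convention that a bare integer cast to a timestamp means seconds since the epoch: an Int64 operand is first cast to Timestamp(Second, tz) and only then widened. A sketch of the distinction:

    use arrow_schema::TimeUnit;
    use datafusion_common::logical_type::LogicalType;
    use datafusion_expr::{cast, lit};

    // interpreted as seconds since the epoch (late 2023)...
    let secs = cast(lit(1_700_000_000_i64), LogicalType::Timestamp(TimeUnit::Second, None));
    // ...whereas a raw nanosecond interpretation would land in 1970.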
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index fa95fc2e051d..a9a2698397a9 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -18,7 +18,6 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; -use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; @@ -28,12 +27,13 @@ use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; +use datafusion_common::logical_type::LogicalType; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( &self, value: Value, - param_data_types: &[DataType], + param_data_types: &[LogicalType], ) -> Result<Expr> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), @@ -96,7 +96,7 @@ /// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on. fn create_placeholder_expr( param: String, - param_data_types: &[DataType], + param_data_types: &[LogicalType], ) -> Result<Expr> { // Parse the placeholder as a number because it is the only support from sqlparser and postgres let index = param[1..].parse::<usize>(); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 63ef86446aaf..a62eb2f126ad 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -37,6 +37,10 @@ use datafusion_common::{ not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError, Result, }; +use datafusion_common::logical_type::field::LogicalField; +use datafusion_common::logical_type::fields::LogicalFields; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::schema::LogicalSchema; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; use datafusion_expr::TableSource; @@ -83,7 +87,7 @@ pub trait ContextProvider { /// Getter for a UDWF fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>>; /// Getter for system/user-defined variable type - fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>; + fn get_variable_type(&self, variable_names: &[String]) -> Option<LogicalType>; /// Get configuration options fn options(&self) -> &ConfigOptions;
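The PREPARE machinery in the next hunk stores parameter types logically as well, so $1, $2, ... placeholders resolve without touching Arrow types; a usage sketch (method names as in the hunk that follows):

    use datafusion_common::logical_type::LogicalType;

    // PREPARE my_plan(INT, TEXT) AS SELECT $1, $2
    let planner_ctx = PlannerContext::new()
        .with_prepare_param_data_types(vec![LogicalType::Int32, LogicalType::Utf8]);
    assert_eq!(
        planner_ctx.prepare_param_data_types(),
        &[LogicalType::Int32, LogicalType::Utf8]
    );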
@@ -156,7 +160,7 @@ impl IdentNormalizer { pub struct PlannerContext { /// Data types for numbered parameters ($1, $2, etc), if supplied /// in `PREPARE` statement - prepare_param_data_types: Arc<Vec<DataType>>, + prepare_param_data_types: Arc<Vec<LogicalType>>, /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap<String, Arc<LogicalPlan>>, @@ -183,7 +187,7 @@ impl PlannerContext { /// Update the PlannerContext with provided prepare_param_data_types pub fn with_prepare_param_data_types( mut self, - prepare_param_data_types: Vec<DataType>, + prepare_param_data_types: Vec<LogicalType>, ) -> Self { self.prepare_param_data_types = prepare_param_data_types.into(); self @@ -205,7 +209,7 @@ } /// Return the types of parameters (`$1`, `$2`, etc) if known - pub fn prepare_param_data_types(&self) -> &[DataType] { + pub fn prepare_param_data_types(&self) -> &[LogicalType] { &self.prepare_param_data_types } @@ -257,7 +261,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - pub fn build_schema(&self, columns: Vec<ColumnDef>) -> Result<Schema> { + pub fn build_schema(&self, columns: Vec<ColumnDef>) -> Result<LogicalSchema> { let mut fields = Vec::with_capacity(columns.len()); for column in columns { @@ -266,14 +270,14 @@ .options .iter() .any(|x| x.option == ColumnOption::NotNull); - fields.push(Field::new( + fields.push(LogicalField::new( self.normalizer.normalize(column.name), data_type, !not_nullable, )); } - Ok(Schema::new(fields)) + Ok(LogicalSchema::new(fields)) } /// Returns a vector of (column_name, default_expr) pairs @@ -377,13 +381,13 @@ }) } - pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> { + pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<LogicalType> { match sql_type { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type, _)) => { // Arrays may be multi-dimensional. let inner_data_type = self.convert_data_type(inner_sql_type)?; - Ok(DataType::new_list(inner_data_type, true)) + Ok(LogicalType::new_list(inner_data_type, true).into()) } SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") } } } - fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> { + fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result<LogicalType> { match sql_type { - SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean), - SQLDataType::TinyInt(_) => Ok(DataType::Int8), - SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(DataType::Int16), - SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(DataType::Int32), - SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(DataType::Int64), - SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8), - SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(DataType::UInt16), + SQLDataType::Boolean | SQLDataType::Bool => Ok(LogicalType::Boolean), + SQLDataType::TinyInt(_) => Ok(LogicalType::Int8), + SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(LogicalType::Int16), + SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(LogicalType::Int32), + SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(LogicalType::Int64), + SQLDataType::UnsignedTinyInt(_) => Ok(LogicalType::UInt8), + SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(LogicalType::UInt16), SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => { - Ok(DataType::UInt32) + Ok(LogicalType::UInt32) } - SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64), - SQLDataType::Float(_) => Ok(DataType::Float32), - SQLDataType::Real |
diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs
index ee2e35b550f6..8978cb10b768 100644
--- a/datafusion/sql/src/relation/join.rs
+++ b/datafusion/sql/src/relation/join.rs
@@ -122,7 +122,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     .build()
             }
             JoinConstraint::Natural => {
-                let left_cols: HashSet<&String> =
+                let left_cols: HashSet<&str> =
                     left.schema().fields().iter().map(|f| f.name()).collect();
                 let keys: Vec<Column> = right
                     .schema()
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index 518972545a48..4e61f38b4abf 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -29,7 +29,6 @@ use crate::planner::{
 };
 use crate::utils::normalize_ident;
 
-use arrow_schema::{DataType, Fields};
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::{
     exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err,
@@ -60,6 +59,9 @@ use sqlparser::ast::{
     TableConstraint, TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value,
 };
 use sqlparser::parser::ParserError::ParserError;
+use datafusion_common::logical_type::fields::LogicalFields;
+use datafusion_common::logical_type::LogicalType;
+use datafusion_common::logical_type::schema::LogicalSchema;
 
 fn ident_to_string(ident: &Ident) -> String {
     normalize_ident(ident.to_owned())
@@ -453,7 +455,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 statement,
             } => {
                 // Convert parser data types to DataFusion data types
-                let data_types: Vec<DataType> = data_types
+                let data_types: Vec<LogicalType> = data_types
                     .into_iter()
                     .map(|t| self.convert_data_type(&t))
                     .collect::<Result<Vec<_>>>()?;
@@ -1222,7 +1224,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // Do a table lookup to verify the table exists
                 let table_ref = self.object_name_to_table_reference(table_name.clone())?;
                 let table_source = self.context_provider.get_table_source(table_ref.clone())?;
-                let schema = (*table_source.schema()).clone();
+                let schema: LogicalSchema = (*table_source.schema()).clone().into();
                 let schema = DFSchema::try_from(schema)?;
                 let scan = LogicalPlanBuilder::scan(
                     object_name_to_string(&table_name),
@@ -1275,7 +1277,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let table_source = self.context_provider.get_table_source(table_name.clone())?;
         let table_schema = Arc::new(DFSchema::try_from_qualified_schema(
             table_name.clone(),
-            &table_source.schema(),
+            &table_source.schema().as_ref().clone().into(),
         )?);
 
         // Overwrite with assignment expressions
@@ -1380,7 +1382,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         // Do a table lookup to verify the table exists
         let table_name = self.object_name_to_table_reference(table_name)?;
         let table_source = self.context_provider.get_table_source(table_name.clone())?;
-        let arrow_schema = (*table_source.schema()).clone();
+        let arrow_schema: LogicalSchema = (*table_source.schema()).clone().into();
         let table_schema = DFSchema::try_from(arrow_schema)?;
 
         // Get insert fields and target table's value indices
@@ -1418,7 +1420,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     Ok(table_schema.field(column_index).clone())
                 })
                 .collect::<Result<Vec<_>>>()?;
-            (Fields::from(fields), value_indices)
+            (LogicalFields::from(fields), value_indices)
         };
 
         // infer types for Values clause... other types should be resolvable the regular way
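The DELETE/UPDATE/INSERT paths above all bridge an Arrow `Schema` from the table source into the logical world with `.into()`. A sketch of that conversion chain, assuming `LogicalSchema: From<Schema>` and `DFSchema: TryFrom<LogicalSchema>` as the hunks use them:

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::logical_type::schema::LogicalSchema;
use datafusion_common::{DFSchema, Result};

fn main() -> Result<()> {
    let arrow_schema = Schema::new(vec![Field::new("c1", DataType::Int64, true)]);
    // Same shape as `let schema: LogicalSchema = (*table_source.schema()).clone().into();`
    let logical: LogicalSchema = arrow_schema.into();
    let df_schema = DFSchema::try_from(logical)?;
    println!("{df_schema}");
    Ok(())
}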
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index ad898de5987a..33186312ac0c 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -27,7 +27,6 @@ use arrow_array::types::{
     TimestampNanosecondType, TimestampSecondType,
 };
 use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
-use arrow_schema::DataType;
 use sqlparser::ast::Value::SingleQuotedString;
 use sqlparser::ast::{
     self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, TimezoneInfo,
@@ -38,6 +37,7 @@ use datafusion_common::{
     internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result,
     ScalarValue,
 };
+use datafusion_common::logical_type::LogicalType;
 use datafusion_expr::{
     expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction},
     Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast,
@@ -957,85 +957,67 @@ impl Unparser<'_> {
         }
     }
 
-    fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result<ast::DataType> {
+    fn arrow_dtype_to_ast_dtype(&self, data_type: &LogicalType) -> Result<ast::DataType> {
         match data_type {
-            DataType::Null => {
+            LogicalType::Null => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Boolean => Ok(ast::DataType::Bool),
-            DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
-            DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
-            DataType::Int32 => Ok(ast::DataType::Integer(None)),
-            DataType::Int64 => Ok(ast::DataType::BigInt(None)),
-            DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
-            DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
-            DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)),
-            DataType::UInt64 => Ok(ast::DataType::UnsignedBigInt(None)),
-            DataType::Float16 => {
+            LogicalType::Boolean => Ok(ast::DataType::Bool),
+            LogicalType::Int8 => Ok(ast::DataType::TinyInt(None)),
+            LogicalType::Int16 => Ok(ast::DataType::SmallInt(None)),
+            LogicalType::Int32 => Ok(ast::DataType::Integer(None)),
+            LogicalType::Int64 => Ok(ast::DataType::BigInt(None)),
+            LogicalType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
+            LogicalType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
+            LogicalType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)),
+            LogicalType::UInt64 => Ok(ast::DataType::UnsignedBigInt(None)),
+            LogicalType::Float16 => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Float32 => Ok(ast::DataType::Float(None)),
-            DataType::Float64 => Ok(ast::DataType::Double),
-            DataType::Timestamp(_, _) => {
+            LogicalType::Float32 => Ok(ast::DataType::Float(None)),
+            LogicalType::Float64 => Ok(ast::DataType::Double),
+            LogicalType::Timestamp(_, _) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Date32 => Ok(ast::DataType::Date),
-            DataType::Date64 => Ok(ast::DataType::Datetime(None)),
-            DataType::Time32(_) => {
+            LogicalType::Date32 => Ok(ast::DataType::Date),
+            LogicalType::Date64 => Ok(ast::DataType::Datetime(None)),
+            LogicalType::Time32(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Time64(_) => {
+            LogicalType::Time64(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Duration(_) => {
+            LogicalType::Duration(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Interval(_) => {
+            LogicalType::Interval(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Binary => {
+            LogicalType::Binary => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::FixedSizeBinary(_) => {
+            LogicalType::FixedSizeBinary(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::LargeBinary => {
+            LogicalType::LargeBinary => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::BinaryView => {
+            LogicalType::Utf8 => Ok(ast::DataType::Varchar(None)),
+            LogicalType::LargeUtf8 => Ok(ast::DataType::Text),
+            LogicalType::List(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Utf8 => Ok(ast::DataType::Varchar(None)),
-            DataType::LargeUtf8 => Ok(ast::DataType::Text),
-            DataType::Utf8View => {
+            LogicalType::FixedSizeList(_, _) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::List(_) => {
+            LogicalType::LargeList(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::FixedSizeList(_, _) => {
+            LogicalType::Struct(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::LargeList(_) => {
-                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
-            }
-            DataType::ListView(_) => {
-                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
-            }
-            DataType::LargeListView(_) => {
-                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
-            }
-            DataType::Struct(_) => {
-                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
-            }
-            DataType::Union(_, _) => {
-                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
-            }
-            DataType::Dictionary(_, _) => {
-                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
-            }
-            DataType::Decimal128(precision, scale)
-            | DataType::Decimal256(precision, scale) => {
+            LogicalType::Decimal128(precision, scale)
+            | LogicalType::Decimal256(precision, scale) => {
                 let mut new_precision = *precision as u64;
                 let mut new_scale = *scale as u64;
                 if *scale < 0 {
@@ -1047,10 +1029,10 @@ impl Unparser<'_> {
                     ast::ExactNumberInfo::PrecisionAndScale(new_precision, new_scale),
                 ))
             }
-            DataType::Map(_, _) => {
+            LogicalType::Map(_, _) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::RunEndEncoded(_, _) => {
+            LogicalType::Extension(_) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
            }
         }
     }
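The only non-trivial arithmetic in `arrow_dtype_to_ast_dtype` is the decimal arm: a negative scale widens the precision and clamps the scale to zero, which is why the test further down expects `Decimal128(10, -2)` to unparse as `DECIMAL(12,0)`. The adjustment in isolation (a sketch; the in-tree code mutates `new_precision`/`new_scale` in place):

// DECIMAL(p, s) with s < 0 denotes p digits shifted left by |s| places,
// so the SQL type needs p + |s| integer digits and no fractional digits.
fn rescale(precision: u8, scale: i8) -> (u64, u64) {
    if scale < 0 {
        (precision as u64 + scale.unsigned_abs() as u64, 0)
    } else {
        (precision as u64, scale as u64)
    }
}

fn main() {
    assert_eq!(rescale(10, -2), (12, 0));
    assert_eq!(rescale(38, 10), (38, 10));
}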
@@ -1063,8 +1045,7 @@ mod tests {
     use std::{any::Any, sync::Arc, vec};
 
     use arrow::datatypes::{Field, Schema};
-    use arrow_schema::DataType::Int8;
-
+    use arrow_schema::DataType;
     use datafusion_common::TableReference;
     use datafusion_expr::{
         case, col, cube, exists, grouping_set, interval_datetime_lit,
@@ -1153,14 +1134,14 @@ mod tests {
             (
                 Expr::Cast(Cast {
                     expr: Box::new(col("a")),
-                    data_type: DataType::Date64,
+                    data_type: LogicalType::Date64,
                 }),
                 r#"CAST(a AS DATETIME)"#,
             ),
             (
                 Expr::Cast(Cast {
                     expr: Box::new(col("a")),
-                    data_type: DataType::UInt32,
+                    data_type: LogicalType::UInt32,
                 }),
                 r#"CAST(a AS INTEGER UNSIGNED)"#,
             ),
@@ -1386,27 +1367,27 @@ mod tests {
                 r#"NOT EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#,
             ),
             (
-                try_cast(col("a"), DataType::Date64),
+                try_cast(col("a"), LogicalType::Date64),
                 r#"TRY_CAST(a AS DATETIME)"#,
             ),
             (
-                try_cast(col("a"), DataType::UInt32),
+                try_cast(col("a"), LogicalType::UInt32),
                 r#"TRY_CAST(a AS INTEGER UNSIGNED)"#,
             ),
             (
-                Expr::ScalarVariable(Int8, vec![String::from("@a")]),
+                Expr::ScalarVariable(LogicalType::Int8, vec![String::from("@a")]),
                 r#"@a"#,
             ),
             (
                 Expr::ScalarVariable(
-                    Int8,
+                    LogicalType::Int8,
                     vec![String::from("@root"), String::from("foo")],
                 ),
                 r#"@root.foo"#,
             ),
             (col("x").eq(placeholder("$1")), r#"(x = $1)"#),
             (
-                out_ref_col(DataType::Int32, "t.a").gt(lit(1)),
+                out_ref_col(LogicalType::Int32, "t.a").gt(lit(1)),
                 r#"(t.a > 1)"#,
             ),
             (
@@ -1481,7 +1462,7 @@ mod tests {
             (
                 Expr::Cast(Cast {
                     expr: Box::new(col("a")),
-                    data_type: DataType::Decimal128(10, -2),
+                    data_type: LogicalType::Decimal128(10, -2),
                 }),
                 r#"CAST(a AS DECIMAL(12,0))"#,
             ),
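For orientation, the public entry point these tests ultimately exercise is `expr_to_sql`, re-exported from `datafusion_sql::unparser` (see the plan_to_sql.rs imports below). A usage sketch:

use datafusion_expr::{col, lit};
use datafusion_sql::unparser::expr_to_sql;

fn main() -> datafusion_common::Result<()> {
    let expr = col("a").gt(lit(4));
    // The returned sqlparser AST displays as SQL text, e.g. `(a > 4)`.
    let ast = expr_to_sql(&expr)?;
    println!("{ast}");
    Ok(())
}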
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index bc27d25cf216..fbe8cd099f2a 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -20,7 +20,7 @@
 use std::collections::HashMap;
 
 use arrow_schema::{
-    DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
+    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
 };
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{
@@ -31,6 +31,7 @@ use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction};
 use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
 use datafusion_expr::{expr_vec_fmt, Expr, ExprSchemable, LogicalPlan};
 use sqlparser::ast::Ident;
+use datafusion_common::logical_type::LogicalType;
 
 /// Make a best-effort attempt at resolving all columns in the expression tree
 pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
@@ -226,7 +227,7 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]>
 pub(crate) fn make_decimal_type(
     precision: Option<u64>,
     scale: Option<u64>,
-) -> Result<DataType> {
+) -> Result<LogicalType> {
     // postgres like behavior
     let (precision, scale) = match (precision, scale) {
         (Some(p), Some(s)) => (p as u8, s as i8),
@@ -247,9 +248,9 @@ pub(crate) fn make_decimal_type(
     } else if precision > DECIMAL128_MAX_PRECISION
         && precision <= DECIMAL256_MAX_PRECISION
     {
-        Ok(DataType::Decimal256(precision, scale))
+        Ok(LogicalType::Decimal256(precision, scale))
     } else {
-        Ok(DataType::Decimal128(precision, scale))
+        Ok(LogicalType::Decimal128(precision, scale))
     }
 }
 
@@ -316,7 +317,7 @@ pub(crate) fn recursive_transform_unnest(
     } = original_expr.transform_up(|expr: Expr| {
         if let Expr::Unnest(Unnest { expr: ref arg }) = expr {
             let (data_type, _) = arg.data_type_and_nullable(input.schema())?;
-            if let DataType::Struct(_) = data_type {
+            if let LogicalType::Struct(_) = data_type {
                 return internal_err!("unnest on struct can ony be applied at the root level of select expression");
             }
             let transformed_exprs = transform(&expr, arg)?;
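`make_decimal_type` picks the narrowest decimal that can hold the requested precision. The selection rule restated against the same Arrow constants (a sketch; the real function also validates zero and out-of-range precision/scale):

use arrow_schema::{DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION};

// 38 digits fit Decimal128; 39..=76 need Decimal256.
fn needs_decimal256(precision: u8) -> bool {
    precision > DECIMAL128_MAX_PRECISION && precision <= DECIMAL256_MAX_PRECISION
}

fn main() {
    assert!(!needs_decimal256(38));
    assert!(needs_decimal256(50));
}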
String::from("foo")], ), r#"@root.foo"#, ), (col("x").eq(placeholder("$1")), r#"(x = $1)"#), ( - out_ref_col(DataType::Int32, "t.a").gt(lit(1)), + out_ref_col(LogicalType::Int32, "t.a").gt(lit(1)), r#"(t.a > 1)"#, ), ( @@ -1481,7 +1462,7 @@ mod tests { ( Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Decimal128(10, -2), + data_type: LogicalType::Decimal128(10, -2), }), r#"CAST(a AS DECIMAL(12,0))"#, ), diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index bc27d25cf216..fbe8cd099f2a 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -20,7 +20,7 @@ use std::collections::HashMap; use arrow_schema::{ - DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ @@ -31,6 +31,7 @@ use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{expr_vec_fmt, Expr, ExprSchemable, LogicalPlan}; use sqlparser::ast::Ident; +use datafusion_common::logical_type::LogicalType; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { @@ -226,7 +227,7 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr pub(crate) fn make_decimal_type( precision: Option, scale: Option, -) -> Result { +) -> Result { // postgres like behavior let (precision, scale) = match (precision, scale) { (Some(p), Some(s)) => (p as u8, s as i8), @@ -247,9 +248,9 @@ pub(crate) fn make_decimal_type( } else if precision > DECIMAL128_MAX_PRECISION && precision <= DECIMAL256_MAX_PRECISION { - Ok(DataType::Decimal256(precision, scale)) + Ok(LogicalType::Decimal256(precision, scale)) } else { - Ok(DataType::Decimal128(precision, scale)) + Ok(LogicalType::Decimal128(precision, scale)) } } @@ -316,7 +317,7 @@ pub(crate) fn recursive_transform_unnest( } = original_expr.transform_up(|expr: Expr| { if let Expr::Unnest(Unnest { expr: ref arg }) = expr { let (data_type, _) = arg.data_type_and_nullable(input.schema())?; - if let DataType::Struct(_) = data_type { + if let LogicalType::Struct(_) = data_type { return internal_err!("unnest on struct can ony be applied at the root level of select expression"); } let transformed_exprs = transform(&expr, arg)?; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 374403d853f9..eefa9e2b3471 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -30,7 +30,6 @@ use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; - use crate::common::MockContextProvider; #[test] diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index f5caaefb3ea0..d7124d5e0180 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -25,6 +25,7 @@ use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, GetExt, Result, TableReference}; +use datafusion_common::logical_type::LogicalType; use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; use 
@@ -202,7 +203,7 @@ impl ContextProvider for MockContextProvider {
         self.udafs.get(name).cloned()
     }
 
-    fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
+    fn get_variable_type(&self, _: &[String]) -> Option<LogicalType> {
         unimplemented!()
     }
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index f196d71d41de..777c935e2bfd 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -43,6 +43,7 @@ use datafusion_functions_aggregate::{
 };
 use rstest::rstest;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
+use datafusion_common::logical_type::LogicalType;
 
 mod cases;
 mod common;
@@ -3660,8 +3661,8 @@ fn test_prepare_statement_should_infer_types() {
     let plan = logical_plan(sql).unwrap();
     let actual_types = plan.get_parameter_types().unwrap();
     let expected_types = HashMap::from([
-        ("$1".to_string(), Some(DataType::Int32)),
-        ("$2".to_string(), Some(DataType::Int64)),
+        ("$1".to_string(), Some(LogicalType::Int32)),
+        ("$2".to_string(), Some(LogicalType::Int64)),
     ]);
     assert_eq!(actual_types, expected_types);
 }
@@ -3674,7 +3675,7 @@ fn test_non_prepare_statement_should_infer_types() {
     let actual_types = plan.get_parameter_types().unwrap();
     let expected_types = HashMap::from([
         // constant 1 is inferred to be int64
-        ("$1".to_string(), Some(DataType::Int64)),
+        ("$1".to_string(), Some(LogicalType::Int64)),
     ]);
     assert_eq!(actual_types, expected_types);
 }
@@ -3849,7 +3850,7 @@ Projection: person.id, orders.order_id
     let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
 
     let actual_types = plan.get_parameter_types().unwrap();
-    let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]);
+    let expected_types = HashMap::from([("$1".to_string(), Some(LogicalType::Int32))]);
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
@@ -3881,7 +3882,7 @@ Projection: person.id, person.age
     let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
 
     let actual_types = plan.get_parameter_types().unwrap();
-    let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]);
+    let expected_types = HashMap::from([("$1".to_string(), Some(LogicalType::Int32))]);
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
@@ -3913,8 +3914,8 @@ Projection: person.id, person.age
 
     let actual_types = plan.get_parameter_types().unwrap();
     let expected_types = HashMap::from([
-        ("$1".to_string(), Some(DataType::Int32)),
-        ("$2".to_string(), Some(DataType::Int32)),
+        ("$1".to_string(), Some(LogicalType::Int32)),
+        ("$2".to_string(), Some(LogicalType::Int32)),
     ]);
     assert_eq!(actual_types, expected_types);
 
@@ -3952,7 +3953,7 @@ Projection: person.id, person.age
     let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
 
     let actual_types = plan.get_parameter_types().unwrap();
-    let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]);
+    let expected_types = HashMap::from([("$1".to_string(), Some(LogicalType::UInt32))]);
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
@@ -3990,8 +3991,8 @@ Dml: op=[Update] table=[person]
 
     let actual_types = plan.get_parameter_types().unwrap();
     let expected_types = HashMap::from([
-        ("$1".to_string(), Some(DataType::Int32)),
-        ("$2".to_string(), Some(DataType::UInt32)),
+        ("$1".to_string(), Some(LogicalType::Int32)),
+        ("$2".to_string(), Some(LogicalType::UInt32)),
     ]);
     assert_eq!(actual_types, expected_types);
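All of these assertions share one shape: `LogicalPlan::get_parameter_types` yields a map from placeholder id to an optionally inferred logical type. A sketch of consuming that result, assuming the same `HashMap<String, Option<LogicalType>>` shape the tests compare against (`logical_plan`/`prepare_stmt_quick_test` are helpers local to this test file):

use std::collections::HashMap;
use datafusion_common::logical_type::LogicalType;

fn describe(params: &HashMap<String, Option<LogicalType>>) {
    for (id, dt) in params {
        match dt {
            Some(dt) => println!("{id} inferred as {dt:?}"),
            None => println!("{id} could not be inferred"),
        }
    }
}

fn main() {
    let params = HashMap::from([("$1".to_string(), Some(LogicalType::Int32))]);
    describe(&params);
}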
@@ -4025,9 +4026,9 @@ fn test_prepare_statement_insert_infer() {
     let actual_types = plan.get_parameter_types().unwrap();
     let expected_types = HashMap::from([
-        ("$1".to_string(), Some(DataType::UInt32)),
-        ("$2".to_string(), Some(DataType::Utf8)),
-        ("$3".to_string(), Some(DataType::Utf8)),
+        ("$1".to_string(), Some(LogicalType::UInt32)),
+        ("$2".to_string(), Some(LogicalType::Utf8)),
+        ("$3".to_string(), Some(LogicalType::Utf8)),
     ]);
     assert_eq!(actual_types, expected_types);
diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs
index 520b6b53b32d..3dfcb34d1a3e 100644
--- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs
+++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs
@@ -15,10 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::Fields;
 use arrow::util::display::ArrayFormatter;
 use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch};
 use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
+use datafusion_common::logical_type::fields::LogicalFields;
+use datafusion_common::logical_type::LogicalType;
 use datafusion_common::DataFusionError;
 use std::path::PathBuf;
 use std::sync::OnceLock;
@@ -243,31 +244,31 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
 }
 
 /// Converts columns to a result as expected by sqllogicteset.
-pub(crate) fn convert_schema_to_types(columns: &Fields) -> Vec<DFColumnType> {
+pub(crate) fn convert_schema_to_types(columns: &LogicalFields) -> Vec<DFColumnType> {
     columns
         .iter()
         .map(|f| f.data_type())
         .map(|data_type| match data_type {
-            DataType::Boolean => DFColumnType::Boolean,
-            DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64 => DFColumnType::Integer,
-            DataType::Float16
-            | DataType::Float32
-            | DataType::Float64
-            | DataType::Decimal128(_, _)
-            | DataType::Decimal256(_, _) => DFColumnType::Float,
-            DataType::Utf8 | DataType::LargeUtf8 => DFColumnType::Text,
-            DataType::Date32
-            | DataType::Date64
-            | DataType::Time32(_)
-            | DataType::Time64(_) => DFColumnType::DateTime,
-            DataType::Timestamp(_, _) => DFColumnType::Timestamp,
+            LogicalType::Boolean => DFColumnType::Boolean,
+            LogicalType::Int8
+            | LogicalType::Int16
+            | LogicalType::Int32
+            | LogicalType::Int64
+            | LogicalType::UInt8
+            | LogicalType::UInt16
+            | LogicalType::UInt32
+            | LogicalType::UInt64 => DFColumnType::Integer,
+            LogicalType::Float16
+            | LogicalType::Float32
+            | LogicalType::Float64
+            | LogicalType::Decimal128(_, _)
+            | LogicalType::Decimal256(_, _) => DFColumnType::Float,
+            LogicalType::Utf8 | LogicalType::LargeUtf8 => DFColumnType::Text,
+            LogicalType::Date32
+            | LogicalType::Date64
+            | LogicalType::Time32(_)
+            | LogicalType::Time64(_) => DFColumnType::DateTime,
+            LogicalType::Timestamp(_, _) => DFColumnType::Timestamp,
             _ => DFColumnType::Another,
         })
         .collect()
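`convert_schema_to_types` collapses logical types into the handful of coarse column classes sqllogictest compares against. The same classification reduced to a sketch over a few variants:

use datafusion_common::logical_type::LogicalType;

fn column_class(dt: &LogicalType) -> char {
    match dt {
        LogicalType::Boolean => 'B',
        LogicalType::Int8 | LogicalType::Int16 | LogicalType::Int32 | LogicalType::Int64 => 'I',
        LogicalType::Utf8 | LogicalType::LargeUtf8 => 'T',
        _ => '?',
    }
}

fn main() {
    assert_eq!(column_class(&LogicalType::Int64), 'I');
    assert_eq!(column_class(&LogicalType::Utf8), 'T');
}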
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 9bc842a12af4..56f659f1d870 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -17,8 +17,13 @@
 use async_recursion::async_recursion;
 use datafusion::arrow::datatypes::{
-    DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
+    Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
 };
+use datafusion::common::logical_type::extension::ExtensionType;
+use datafusion::common::logical_type::field::{LogicalField, LogicalFieldRef};
+use datafusion::common::logical_type::fields::LogicalFields;
+use datafusion::common::logical_type::schema::LogicalSchema;
+use datafusion::common::logical_type::LogicalType;
 use datafusion::common::{
     not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
 };
@@ -349,12 +354,12 @@ fn make_renamed_schema(
     dfs_names: &Vec<String>,
 ) -> Result<DFSchema> {
     fn rename_inner_fields(
-        dtype: &DataType,
+        dtype: &LogicalType,
         dfs_names: &Vec<String>,
         name_idx: &mut usize,
-    ) -> Result<DataType> {
+    ) -> Result<LogicalType> {
         match dtype {
-            DataType::Struct(fields) => {
+            LogicalType::Struct(fields) => {
                 let fields = fields
                     .iter()
                     .map(|f| {
@@ -364,16 +369,16 @@ fn make_renamed_schema(
                         ))
                     })
                     .collect::<Result<_>>()?;
-                Ok(DataType::Struct(fields))
+                Ok(LogicalType::Struct(fields))
             }
-            DataType::List(inner) => Ok(DataType::List(FieldRef::new(
+            LogicalType::List(inner) => Ok(LogicalType::List(LogicalFieldRef::new(
                 (**inner).to_owned().with_data_type(rename_inner_fields(
                     inner.data_type(),
                     dfs_names,
                     name_idx,
                 )?),
             ))),
-            DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
+            LogicalType::LargeList(inner) => Ok(LogicalType::LargeList(LogicalFieldRef::new(
                 (**inner).to_owned().with_data_type(rename_inner_fields(
                     inner.data_type(),
                     dfs_names,
@@ -386,7 +391,7 @@ fn make_renamed_schema(
 
     let mut name_idx = 0;
 
-    let (qualifiers, fields): (_, Vec<Field>) = schema
+    let (qualifiers, fields): (_, Vec<LogicalField>) = schema
         .iter()
         .map(|(q, f)| {
             let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
@@ -415,7 +420,7 @@ fn make_renamed_schema(
 
     DFSchema::from_field_specific_qualified_schema(
         qualifiers,
-        &Arc::new(Schema::new(fields)),
+        &Arc::new(LogicalSchema::new(fields)),
     )
 }
 
@@ -863,7 +868,7 @@ pub async fn from_substrait_rel(
 }
 
 /// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise
-/// conflict with the columns from the other. 
+/// conflict with the columns from the other.
 /// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For
 /// Substrait the names don't matter since it only refers to columns by indices, however DataFusion
 /// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names).
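`make_renamed_schema` works because Substrait ships one flat, depth-first list of field names covering top-level and nested struct fields, and `name_idx` walks through it as the schema is rebuilt. The consumption pattern in isolation (`next_struct_field_name` itself also validates exhaustion, which is elided here):

fn next_name<'a>(dfs_names: &'a [String], name_idx: &mut usize) -> Option<&'a str> {
    let name = dfs_names.get(*name_idx)?;
    *name_idx += 1;
    Some(name.as_str())
}

fn main() {
    // Names for: top-level "a", struct "b", and b's nested child.
    let names: Vec<String> = vec!["a".into(), "b".into(), "b_child".into()];
    let mut idx = 0;
    assert_eq!(next_name(&names, &mut idx), Some("a"));
    assert_eq!(next_name(&names, &mut idx), Some("b"));
    assert_eq!(next_name(&names, &mut idx), Some("b_child"));
}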
@@ -1348,7 +1353,7 @@ pub async fn from_substrait_rex(
         }
     }
 }
 
-pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<DataType> {
+pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<LogicalType> {
     from_substrait_type(dt, &[], &mut 0)
 }
 
@@ -1356,77 +1361,77 @@ fn from_substrait_type(
     dt: &Type,
     dfs_names: &[String],
     name_idx: &mut usize,
-) -> Result<DataType> {
+) -> Result<LogicalType> {
     match &dt.kind {
         Some(s_kind) => match s_kind {
-            r#type::Kind::Bool(_) => Ok(DataType::Boolean),
+            r#type::Kind::Bool(_) => Ok(LogicalType::Boolean),
             r#type::Kind::I8(integer) => match integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8),
+                DEFAULT_TYPE_VARIATION_REF => Ok(LogicalType::Int8),
+                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(LogicalType::UInt8),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
             r#type::Kind::I16(integer) => match integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16),
+                DEFAULT_TYPE_VARIATION_REF => Ok(LogicalType::Int16),
+                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(LogicalType::UInt16),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
             r#type::Kind::I32(integer) => match integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32),
+                DEFAULT_TYPE_VARIATION_REF => Ok(LogicalType::Int32),
+                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(LogicalType::UInt32),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
             r#type::Kind::I64(integer) => match integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64),
+                DEFAULT_TYPE_VARIATION_REF => Ok(LogicalType::Int64),
+                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(LogicalType::UInt64),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
-            r#type::Kind::Fp32(_) => Ok(DataType::Float32),
-            r#type::Kind::Fp64(_) => Ok(DataType::Float64),
+            r#type::Kind::Fp32(_) => Ok(LogicalType::Float32),
+            r#type::Kind::Fp64(_) => Ok(LogicalType::Float64),
             r#type::Kind::Timestamp(ts) => match ts.type_variation_reference {
                 TIMESTAMP_SECOND_TYPE_VARIATION_REF => {
-                    Ok(DataType::Timestamp(TimeUnit::Second, None))
+                    Ok(LogicalType::Timestamp(TimeUnit::Second, None))
                 }
                 TIMESTAMP_MILLI_TYPE_VARIATION_REF => {
-                    Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
+                    Ok(LogicalType::Timestamp(TimeUnit::Millisecond, None))
                 }
                 TIMESTAMP_MICRO_TYPE_VARIATION_REF => {
-                    Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
+                    Ok(LogicalType::Timestamp(TimeUnit::Microsecond, None))
                 }
                 TIMESTAMP_NANO_TYPE_VARIATION_REF => {
-                    Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
+                    Ok(LogicalType::Timestamp(TimeUnit::Nanosecond, None))
                 }
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
             r#type::Kind::Date(date) => match date.type_variation_reference {
-                DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32),
-                DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64),
+                DATE_32_TYPE_VARIATION_REF => Ok(LogicalType::Date32),
+                DATE_64_TYPE_VARIATION_REF => Ok(LogicalType::Date64),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
             r#type::Kind::Binary(binary) => match binary.type_variation_reference {
-                DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary),
-                LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary),
+                DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(LogicalType::Binary),
+                LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(LogicalType::LargeBinary),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
             r#type::Kind::FixedBinary(fixed) => {
-                Ok(DataType::FixedSizeBinary(fixed.length))
+                Ok(LogicalType::FixedSizeBinary(fixed.length))
             }
             r#type::Kind::String(string) => match string.type_variation_reference {
-                DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8),
-                LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8),
+                DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(LogicalType::Utf8),
+                LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(LogicalType::LargeUtf8),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
@@ -1435,15 +1440,15 @@ fn from_substrait_type(
                 let inner_type = list.r#type.as_ref().ok_or_else(|| {
                     substrait_datafusion_err!("List type must have inner type")
                 })?;
-                let field = Arc::new(Field::new_list_field(
+                let field = Arc::new(LogicalField::new_list_field(
                     from_substrait_type(inner_type, dfs_names, name_idx)?,
-                    // We ignore Substrait's nullability here to match to_substrait_literal 
+                    // We ignore Substrait's nullability here to match to_substrait_literal
                     // which always creates nullable lists
                     true,
                 ));
                 match list.type_variation_reference {
-                    DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)),
-                    LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)),
+                    DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(LogicalType::List(field)),
+                    LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(LogicalType::LargeList(field)),
                     v => not_impl_err!(
                         "Unsupported Substrait type variation {v} of type {s_kind:?}"
                     )?,
@@ -1456,19 +1461,19 @@ fn from_substrait_type(
                 let value_type = map.value.as_ref().ok_or_else(|| {
                     substrait_datafusion_err!("Map type must have value type")
                 })?;
-                let key_field = Arc::new(Field::new(
+                let key_field = Arc::new(LogicalField::new(
                     "key",
                     from_substrait_type(key_type, dfs_names, name_idx)?,
                     false,
                 ));
-                let value_field = Arc::new(Field::new(
+                let value_field = Arc::new(LogicalField::new(
                     "value",
                     from_substrait_type(value_type, dfs_names, name_idx)?,
                     true,
                 ));
                 match map.type_variation_reference {
                     DEFAULT_CONTAINER_TYPE_VARIATION_REF => {
-                        Ok(DataType::Map(Arc::new(Field::new_struct(
+                        Ok(LogicalType::Map(Arc::new(LogicalField::new_struct(
                             "entries",
                             [key_field, value_field],
                             false, // The inner map field is always non-nullable (Arrow #1697),
@@ -1481,10 +1486,10 @@ fn from_substrait_type(
             }
             r#type::Kind::Decimal(d) => match d.type_variation_reference {
                 DECIMAL_128_TYPE_VARIATION_REF => {
-                    Ok(DataType::Decimal128(d.precision as u8, d.scale as i8))
+                    Ok(LogicalType::Decimal128(d.precision as u8, d.scale as i8))
                 }
                 DECIMAL_256_TYPE_VARIATION_REF => {
-                    Ok(DataType::Decimal256(d.precision as u8, d.scale as i8))
+                    Ok(LogicalType::Decimal256(d.precision as u8, d.scale as i8))
                 }
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
@@ -1493,13 +1498,13 @@ fn from_substrait_type(
                 ),
             },
             r#type::Kind::UserDefined(u) => {
                 match u.type_reference {
                     INTERVAL_YEAR_MONTH_TYPE_REF => {
-                        Ok(DataType::Interval(IntervalUnit::YearMonth))
+                        Ok(LogicalType::Interval(IntervalUnit::YearMonth))
                     }
                     INTERVAL_DAY_TIME_TYPE_REF => {
-                        Ok(DataType::Interval(IntervalUnit::DayTime))
+                        Ok(LogicalType::Interval(IntervalUnit::DayTime))
                     }
                     INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
-                        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+                        Ok(LogicalType::Interval(IntervalUnit::MonthDayNano))
                     }
                     _ => not_impl_err!(
                         "Unsupported Substrait user defined type with ref {} and variation {}",
@@ -1508,11 +1513,11 @@ fn from_substrait_type(
                     ),
                 }
             },
-            r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
+            r#type::Kind::Struct(s) => Ok(LogicalType::Struct(from_substrait_struct_type(
                 s, dfs_names, name_idx,
             )?)),
-            r#type::Kind::Varchar(_) => Ok(DataType::Utf8),
-            r#type::Kind::FixedChar(_) => Ok(DataType::Utf8),
+            r#type::Kind::Varchar(_) => Ok(LogicalType::Utf8),
+            r#type::Kind::FixedChar(_) => Ok(LogicalType::Utf8),
             _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
         },
         _ => not_impl_err!("`None` Substrait kind is not supported"),
@@ -1523,10 +1528,10 @@ fn from_substrait_struct_type(
     s: &r#type::Struct,
     dfs_names: &[String],
     name_idx: &mut usize,
-) -> Result<Fields> {
+) -> Result<LogicalFields> {
     let mut fields = vec![];
     for (i, f) in s.types.iter().enumerate() {
-        let field = Field::new(
+        let field = LogicalField::new(
             next_struct_field_name(i, dfs_names, name_idx)?,
             from_substrait_type(f, dfs_names, name_idx)?,
             is_substrait_type_nullable(f)?,
@@ -1778,10 +1783,10 @@ fn from_substrait_literal(
         )?;
         match lit.type_variation_reference {
             DEFAULT_CONTAINER_TYPE_VARIATION_REF => {
-                ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type))
+                ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type.physical_type()))
             }
             LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList(
-                ScalarValue::new_large_list(&[], &element_type),
+                ScalarValue::new_large_list(&[], &element_type.physical_type()),
             ),
             others => {
                 return substrait_err!("Unknown type variation reference {others}");
@@ -1958,7 +1963,7 @@ fn from_substrait_null(
             d.scale as i8,
         )),
         r#type::Kind::List(l) => {
-            let field = Field::new_list_field(
+            let field = LogicalField::new_list_field(
                 from_substrait_type(
                     l.r#type.clone().unwrap().as_ref(),
                     dfs_names,
@@ -1968,10 +1973,10 @@ fn from_substrait_null(
             );
             match l.type_variation_reference {
                 DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::List(
-                    Arc::new(GenericListArray::new_null(field.into(), 1)),
+                    Arc::new(GenericListArray::new_null(FieldRef::new(field.into()), 1)),
                 )),
                 LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeList(
-                    Arc::new(GenericListArray::new_null(field.into(), 1)),
+                    Arc::new(GenericListArray::new_null(FieldRef::new(field.into()), 1)),
                 )),
                 v => not_impl_err!(
                     "Unsupported Substrait type variation {v} of type {kind:?}"
@@ -1979,7 +1984,7 @@ fn from_substrait_null(
             }
         }
         r#type::Kind::Struct(s) => {
-            let fields = from_substrait_struct_type(s, dfs_names, name_idx)?;
+            let fields: Fields = from_substrait_struct_type(s, dfs_names, name_idx)?.into();
             Ok(ScalarStructBuilder::new_null(fields))
         }
         _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
     }
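Two literal paths in this file still need a physical Arrow type (building typed empty `ScalarValue` lists), so they drop down via `.physical_type()`; the `ExtensionType` import at the top of the file is what brings that method into scope. A sketch of the bridging, with the concrete mapping assumed rather than confirmed:

use datafusion::common::logical_type::extension::ExtensionType;
use datafusion::common::logical_type::LogicalType;

fn main() {
    let logical = LogicalType::Utf8;
    // Assumed: the logical Utf8 lowers to the physical arrow DataType::Utf8.
    let physical = logical.physical_type();
    println!("{physical:?}");
}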
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 302f38606bfb..07c40f9a9fbd 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -26,14 +26,26 @@ use datafusion::logical_expr::{
     CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
 };
 use datafusion::{
-    arrow::datatypes::{DataType, TimeUnit},
+    arrow::datatypes::TimeUnit,
     error::{DataFusionError, Result},
     logical_expr::{WindowFrame, WindowFrameBound},
     prelude::{JoinType, SessionContext},
     scalar::ScalarValue,
 };
+use crate::variation_const::{
+    DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
+    DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
+    DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
+    INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_DAY_TIME_TYPE_URL,
+    INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL,
+    INTERVAL_YEAR_MONTH_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_URL,
+    LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF,
+    TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF,
+    TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
+};
 use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
+use datafusion::common::logical_type::LogicalType;
 use datafusion::common::{
     exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
 };
@@ -91,18 +103,6 @@ use substrait::{
     version,
 };
 
-use crate::variation_const::{
-    DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
-    DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
-    DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
-    INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_DAY_TIME_TYPE_URL,
-    INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL,
-    INTERVAL_YEAR_MONTH_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_URL,
-    LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF,
-    TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF,
-    TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
-};
-
 /// Convert DataFusion LogicalPlan to Substrait Plan
 pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box<Plan>> {
     // Parse relation nodes
@@ -586,9 +586,9 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result<NamedStruct> {
     // Substrait wants a list of all field names, including nested fields from structs,
     // also from within e.g. lists and maps. However, it does not want the list and map field names
     // themselves - only proper structs fields are considered to have useful names.
-    fn names_dfs(dtype: &DataType) -> Result<Vec<String>> {
+    fn names_dfs(dtype: &LogicalType) -> Result<Vec<String>> {
         match dtype {
-            DataType::Struct(fields) => {
+            LogicalType::Struct(fields) => {
                 let mut names = Vec::new();
                 for field in fields {
                     names.push(field.name().to_string());
@@ -596,10 +596,10 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result<NamedStruct> {
                 }
                 Ok(names)
             }
-            DataType::List(l) => names_dfs(l.data_type()),
-            DataType::LargeList(l) => names_dfs(l.data_type()),
-            DataType::Map(m, _) => match m.data_type() {
-                DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
+            LogicalType::List(l) => names_dfs(l.data_type()),
+            LogicalType::LargeList(l) => names_dfs(l.data_type()),
+            LogicalType::Map(m, _) => match m.data_type() {
+                LogicalType::Struct(key_and_value) if key_and_value.len() == 2 => {
                     let key_names =
                         names_dfs(key_and_value.first().unwrap().data_type())?;
                    let value_names =
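`names_dfs` only records names for struct fields, while list and map wrappers are traversed transparently. A simplified recursion over the struct and list cases (map elided; the `LogicalFields` iteration and field accessors are assumed from their usage elsewhere in this diff):

use std::sync::Arc;
use datafusion::common::logical_type::field::LogicalField;
use datafusion::common::logical_type::fields::LogicalFields;
use datafusion::common::logical_type::LogicalType;

fn names_dfs(dtype: &LogicalType) -> Vec<String> {
    match dtype {
        LogicalType::Struct(fields) => fields
            .iter()
            .flat_map(|f| {
                // Record the struct field's own name, then recurse into it.
                std::iter::once(f.name().to_string()).chain(names_dfs(f.data_type()))
            })
            .collect(),
        LogicalType::List(inner) | LogicalType::LargeList(inner) => {
            names_dfs(inner.data_type())
        }
        _ => vec![],
    }
}

fn main() {
    let fields = LogicalFields::from(vec![Arc::new(LogicalField::new(
        "a",
        LogicalType::Int32,
        true,
    ))]);
    assert_eq!(names_dfs(&LogicalType::Struct(fields)), vec!["a".to_string()]);
}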
@@ -1433,83 +1433,83 @@ pub fn to_substrait_rex(
     }
 }
 
-fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
+fn to_substrait_type(dt: &LogicalType, nullable: bool) -> Result<substrait::proto::Type> {
     let nullability = if nullable {
         r#type::Nullability::Nullable as i32
     } else {
         r#type::Nullability::Required as i32
     };
     match dt {
-        DataType::Null => internal_err!("Null cast is not valid"),
-        DataType::Boolean => Ok(substrait::proto::Type {
+        LogicalType::Null => internal_err!("Null cast is not valid"),
+        LogicalType::Boolean => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Bool(r#type::Boolean {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Int8 => Ok(substrait::proto::Type {
+        LogicalType::Int8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I8(r#type::I8 {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::UInt8 => Ok(substrait::proto::Type {
+        LogicalType::UInt8 => Ok(substrait::proto::Type {
            kind: Some(r#type::Kind::I8(r#type::I8 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Int16 => Ok(substrait::proto::Type {
+        LogicalType::Int16 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I16(r#type::I16 {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::UInt16 => Ok(substrait::proto::Type {
+        LogicalType::UInt16 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I16(r#type::I16 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Int32 => Ok(substrait::proto::Type {
+        LogicalType::Int32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I32(r#type::I32 {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::UInt32 => Ok(substrait::proto::Type {
+        LogicalType::UInt32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I32(r#type::I32 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Int64 => Ok(substrait::proto::Type {
+        LogicalType::Int64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I64(r#type::I64 {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::UInt64 => Ok(substrait::proto::Type {
+        LogicalType::UInt64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I64(r#type::I64 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
         // Float16 is not supported in Substrait
-        DataType::Float32 => Ok(substrait::proto::Type {
+        LogicalType::Float32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Float64 => Ok(substrait::proto::Type {
+        LogicalType::Float64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
         // Timezone is ignored.
-        DataType::Timestamp(unit, _) => {
+        LogicalType::Timestamp(unit, _) => {
             let type_variation_reference = match unit {
                 TimeUnit::Second => TIMESTAMP_SECOND_TYPE_VARIATION_REF,
                 TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_VARIATION_REF,
@@ -1523,19 +1523,19 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
             }
         }
-        DataType::Date32 => Ok(substrait::proto::Type {
+        LogicalType::Date32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Date(r#type::Date {
                 type_variation_reference: DATE_32_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Date64 => Ok(substrait::proto::Type {
+        LogicalType::Date64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Date(r#type::Date {
                 type_variation_reference: DATE_64_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Interval(interval_unit) => {
+        LogicalType::Interval(interval_unit) => {
             // define two type parameters for convenience
             let i32_param = Parameter {
                 parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
@@ -1578,38 +1578,38 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
                 }
             }
         }
-        DataType::Binary => Ok(substrait::proto::Type {
+        LogicalType::Binary => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Binary(r#type::Binary {
                 type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type {
+        LogicalType::FixedSizeBinary(length) => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary {
                 length: *length,
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::LargeBinary => Ok(substrait::proto::Type {
+        LogicalType::LargeBinary => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Binary(r#type::Binary {
                 type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::Utf8 => Ok(substrait::proto::Type {
+        LogicalType::Utf8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::String(r#type::String {
                 type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::LargeUtf8 => Ok(substrait::proto::Type {
+        LogicalType::LargeUtf8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::String(r#type::String {
                 type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF,
                 nullability,
             })),
         }),
-        DataType::List(inner) => {
+        LogicalType::List(inner) => {
             let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::List(Box::new(r#type::List {
@@ -1619,7 +1619,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
             })
         }
-        DataType::LargeList(inner) => {
+        LogicalType::LargeList(inner) => {
             let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::List(Box::new(r#type::List {
@@ -1629,8 +1629,8 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
             })
         }
-        DataType::Map(inner, _) => match inner.data_type() {
-            DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
+        LogicalType::Map(inner, _) => match inner.data_type() {
+            LogicalType::Struct(key_and_value) if key_and_value.len() == 2 => {
                 let key_type = to_substrait_type(
                     key_and_value[0].data_type(),
                     key_and_value[0].is_nullable(),
@@ -1650,7 +1650,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
             }
             _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
         },
-        DataType::Struct(fields) => {
+        LogicalType::Struct(fields) => {
             let field_types = fields
                 .iter()
                 .map(|field| to_substrait_type(field.data_type(), field.is_nullable()))
                 .collect::<Result<Vec<_>>>()?;
@@ -1663,7 +1663,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
             })
         }
-        DataType::Decimal128(p, s) => Ok(substrait::proto::Type {
+        LogicalType::Decimal128(p, s) => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Decimal(r#type::Decimal {
                 type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF,
                 nullability,
@@ -1671,7 +1671,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::Type> {
             })),
         }),
-        DataType::Decimal256(p, s) => Ok(substrait::proto::Type {
+        LogicalType::Decimal256(p, s) => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Decimal(r#type::Decimal {
                 type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF,
                 nullability,
@@ -1897,7 +1897,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
             nullable: true,
             type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
             literal_type: Some(LiteralType::Null(to_substrait_type(
-                &value.data_type(),
+                &value.data_type().into(),
                 true,
             )?)),
         });
@@ -2097,7 +2097,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
         .collect::<Result<Vec<_>>>()?;
 
     if values.is_empty() {
-        let et = match to_substrait_type(array.data_type(), array.is_nullable())? {
+        let et = match to_substrait_type(&array.data_type().to_owned().into(), array.is_nullable())? {
             substrait::proto::Type {
                 kind: Some(r#type::Kind::List(lt)),
             } => lt.as_ref().to_owned(),
@@ -2223,7 +2223,7 @@ mod test {
         from_substrait_literal_without_names, from_substrait_type_without_names,
     };
     use datafusion::arrow::array::GenericListArray;
-    use datafusion::arrow::datatypes::Field;
+    use datafusion::arrow::datatypes::{DataType, Field};
     use datafusion::common::scalar::ScalarStructBuilder;
 
     use super::*;
@@ -2376,9 +2376,10 @@ mod test {
         // As DataFusion doesn't consider nullability as a property of the type, but field,
         // it doesn't matter if we set nullability to true or false here.
-        let substrait = to_substrait_type(&dt, true)?;
+        let lt = dt.into();
+        let substrait = to_substrait_type(&lt, true)?;
         let roundtrip_dt = from_substrait_type_without_names(&substrait)?;
-        assert_eq!(dt, roundtrip_dt);
+        assert_eq!(lt, roundtrip_dt);
         Ok(())
     }
 }
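The adjusted round-trip test is the clearest statement of the new contract: an Arrow `DataType` enters the logical world once, and Substrait conversion round-trips at the logical level, with nullability intentionally excluded from the comparison. Restated as a sketch (`to_substrait_type`/`from_substrait_type_without_names` are crate-internal, so only the entry conversion is shown; `From<DataType> for LogicalType` is assumed from the test's `dt.into()`):

use datafusion::arrow::datatypes::DataType;
use datafusion::common::logical_type::LogicalType;

fn main() {
    let dt = DataType::Date32;
    // Nullability is a field property, not a type property, so it plays
    // no part in the logical round trip.
    let lt: LogicalType = dt.into();
    assert_eq!(lt, LogicalType::Date32);
}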