diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index 0550741a9f9..9098285b482 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -37,8 +37,7 @@ use crate::{
 };
 use crate::{
     physical_plan::{
-        aggregates, expressions::binary_operator_data_type, functions,
-        type_coercion::can_coerce_from, udf::ScalarUDF,
+        aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
     },
     sql::parser::FileType,
 };
@@ -323,21 +322,19 @@ impl Expr {
     ///
     /// # Errors
     ///
-    /// This function errors when it is impossible to cast the expression to the target [arrow::datatypes::DataType].
+    /// Currently no errors happen at plan time. If it is impossible
+    /// to cast the expression to the target
+    /// [arrow::datatypes::DataType] then an error will occur at
+    /// runtime.
     pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result<Expr> {
         let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             Ok(self.clone())
-        } else if can_coerce_from(cast_to_type, &this_type) {
+        } else {
             Ok(Expr::Cast {
                 expr: Box::new(self.clone()),
                 data_type: cast_to_type.clone(),
             })
-        } else {
-            Err(ExecutionError::General(format!(
-                "Cannot automatically convert {:?} to {:?}",
-                this_type, cast_to_type
-            )))
         }
     }
 
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 5640daa5303..fd1a09826f5 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -21,8 +21,8 @@ use std::sync::Arc;
 extern crate arrow;
 extern crate datafusion;
 
-use arrow::record_batch::RecordBatch;
 use arrow::{array::*, datatypes::TimeUnit};
+use arrow::{datatypes::Int32Type, record_batch::RecordBatch};
 use arrow::{
     datatypes::{DataType, Field, Schema, SchemaRef},
     util::pretty::array_value_to_string,
@@ -918,14 +918,20 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
 /// Execute query and return result set as 2-d table of Vecs
 /// `result[row][column]`
 async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> {
-    let plan = ctx.create_logical_plan(&sql).unwrap();
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx.create_logical_plan(&sql).expect(&msg);
     let logical_schema = plan.schema();
-    let plan = ctx.optimize(&plan).unwrap();
+
+    let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan);
+    let plan = ctx.optimize(&plan).expect(&msg);
     let optimized_logical_schema = plan.schema();
-    let plan = ctx.create_physical_plan(&plan).unwrap();
+
+    let msg = format!("Creating physical plan for '{}': {:?}", sql, plan);
+    let plan = ctx.create_physical_plan(&plan).expect(&msg);
     let physical_schema = plan.schema();
 
-    let results = ctx.collect(plan).await.unwrap();
+    let msg = format!("Executing physical plan for '{}': {:?}", sql, plan);
+    let results = ctx.collect(plan).await.expect(&msg);
 
     assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref());
     assert_eq!(logical_schema.as_ref(), physical_schema.as_ref());
@@ -1200,3 +1206,69 @@ async fn query_is_not_null() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_on_string_dictionary() -> Result<()> {
+    // Test to ensure DataFusion can operate on dictionary types
+    // Use StringDictionary (32 bit indexes = keys)
+    let field_type =
+        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
+    let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)]));
+
+    let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
+    let values_builder = StringBuilder::new(10);
+    let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder);
+
+    builder.append("one")?;
+    builder.append_null()?;
+    builder.append("three")?;
+    let array = Arc::new(builder.finish());
+
+    let data = RecordBatch::try_new(schema.clone(), vec![array])?;
+
+    let table = MemTable::new(schema, vec![vec![data]])?;
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Box::new(table));
+
+    // Basic SELECT
+    let sql = "SELECT * FROM test";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]];
+    assert_eq!(expected, actual);
+
+    // basic filtering
+    let sql = "SELECT * FROM test WHERE d1 IS NOT NULL";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![vec!["one"], vec!["three"]];
+    assert_eq!(expected, actual);
+
+    // The following queries are not yet supported
+
+    // // filtering with constant
+    // let sql = "SELECT * FROM test WHERE d1 = 'three'";
+    // let actual = execute(&mut ctx, sql).await;
+    // let expected = vec![
+    //     vec!["three"],
+    // ];
+    // assert_eq!(expected, actual);
+
+    // // Expression evaluation
+    // let sql = "SELECT concat(d1, '-foo') FROM test";
+    // let actual = execute(&mut ctx, sql).await;
+    // let expected = vec![
+    //     vec!["one-foo"],
+    //     vec!["NULL"],
+    //     vec!["three-foo"],
+    // ];
+    // assert_eq!(expected, actual);
+
+    // // aggregation
+    // let sql = "SELECT COUNT(d1) FROM test";
+    // let actual = execute(&mut ctx, sql).await;
+    // let expected = vec![
+    //     vec!["2"]
+    // ];
+    // assert_eq!(expected, actual);
+
+    Ok(())
+}