diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index 7b0c6bc9a86..70acf5a7445 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -72,6 +72,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8, (Utf8, Date32(DateUnit::Day)) => true, + (Utf8, Date64(DateUnit::Millisecond)) => true, (Utf8, _) => DataType::is_numeric(to_type), (_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary, @@ -399,6 +400,26 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { } Ok(Arc::new(builder.finish()) as ArrayRef) } + Date64(DateUnit::Millisecond) => { + use chrono::{NaiveDate, NaiveTime}; + let zero_time = NaiveTime::from_hms(0, 0, 0); + let string_array = array.as_any().downcast_ref::().unwrap(); + let mut builder = PrimitiveBuilder::::new(string_array.len()); + for i in 0..string_array.len() { + if string_array.is_null(i) { + builder.append_null()?; + } else { + match NaiveDate::parse_from_str(string_array.value(i), "%Y-%m-%d") + { + Ok(date) => builder.append_value( + date.and_time(zero_time).timestamp_millis() as i64, + )?, + Err(_) => builder.append_null()?, // not a valid date + }; + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } _ => Err(ArrowError::ComputeError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -2780,6 +2801,31 @@ mod tests { assert_eq!(false, c.is_valid(4)); // "2000" } + #[test] + fn test_cast_utf8_to_date64() { + let a = StringArray::from(vec![ + "2000-01-01", // valid date with leading 0s + "2000-2-2", // valid date without leading 0s + "2000-00-00", // invalid month and day + "2000-01-01T12:00:00", // date + time is invalid + "2000", // just a year is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64(DateUnit::Millisecond)).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + + // test valid inputs + assert_eq!(true, c.is_valid(0)); // "2000-01-01" + assert_eq!(946684800000, c.value(0)); + assert_eq!(true, c.is_valid(1)); // "2000-2-2" + assert_eq!(949449600000, c.value(1)); + + // test invalid inputs + assert_eq!(false, c.is_valid(2)); // "2000-00-00" + assert_eq!(false, c.is_valid(3)); // "2000-01-01T12:00:00" + assert_eq!(false, c.is_valid(4)); // "2000" + } + #[test] fn test_can_cast_types() { // this function attempts to ensure that can_cast_types stays diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 2ed9ab0d8af..cd3d9d8adff 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -21,7 +21,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, DateUnit, Field, Schema}; use arrow::util::pretty; use datafusion::datasource::parquet::ParquetTable; use datafusion::datasource::{CsvFile, MemTable, TableProvider}; @@ -187,7 +187,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result Result '1995-03-15' + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' group by l_orderkey, o_orderdate, @@ -337,8 +337,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result= '1994-01-01' - and o_orderdate < '1995-01-01' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1995-01-01' group by n_name order by @@ -363,9 +363,9 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result= '1994-01-01' - and l_shipdate < '1995-01-01' - and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1995-01-01' + and l_discount > 0.06 - 0.01 and l_discount < 0.06 + 0.01 and l_quantity < 24;" ), @@ -399,7 +399,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result '1995-01-01' and l_shipdate < '1996-12-31' + and l_shipdate > date '1995-01-01' and l_shipdate < date '1996-12-31' ) as shipping group by supp_nation, @@ -442,7 +442,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result Result ctx.create_logical_plan( + // "select + // c_custkey, + // c_name, + // sum(l_extendedprice * (1 - l_discount)) as revenue, + // c_acctbal, + // n_name, + // c_address, + // c_phone, + // c_comment + // from + // customer, + // orders, + // lineitem, + // nation + // where + // c_custkey = o_custkey + // and l_orderkey = o_orderkey + // and o_orderdate >= date '1993-10-01' + // and o_orderdate < date '1993-10-01' + interval '3' month + // and l_returnflag = 'R' + // and c_nationkey = n_nationkey + // group by + // c_custkey, + // c_name, + // c_acctbal, + // c_phone, + // n_name, + // c_address, + // c_comment + // order by + // revenue desc;" + // ), 10 => ctx.create_logical_plan( "select c_custkey, @@ -504,8 +537,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result= '1993-10-01' - and o_orderdate < '1994-01-01' + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1994-01-01' and l_returnflag = 'R' and c_nationkey = n_nationkey group by @@ -606,8 +639,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result= '1994-01-01' - and l_receiptdate < '1995-01-01' + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1995-01-01' group by l_shipmode order by @@ -649,8 +682,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result= '1995-09-01' - and l_shipdate < '1995-10-01';" + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-10-01';" ), 15 => ctx.create_logical_plan( @@ -1072,7 +1105,7 @@ fn get_schema(table: &str) -> Schema { Field::new("o_custkey", DataType::UInt32, false), Field::new("o_orderstatus", DataType::Utf8, false), Field::new("o_totalprice", DataType::Float64, false), // decimal - Field::new("o_orderdate", DataType::Utf8, false), + Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false), Field::new("o_orderpriority", DataType::Utf8, false), Field::new("o_clerk", DataType::Utf8, false), Field::new("o_shippriority", DataType::UInt32, false), @@ -1090,9 +1123,9 @@ fn get_schema(table: &str) -> Schema { Field::new("l_tax", DataType::Float64, false), // decimal Field::new("l_returnflag", DataType::Utf8, false), Field::new("l_linestatus", DataType::Utf8, false), - Field::new("l_shipdate", DataType::Utf8, false), - Field::new("l_commitdate", DataType::Utf8, false), - Field::new("l_receiptdate", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false), + Field::new("l_commitdate", DataType::Date32(DateUnit::Day), false), + Field::new("l_receiptdate", DataType::Date32(DateUnit::Day), false), Field::new("l_shipinstruct", DataType::Utf8, false), Field::new("l_shipmode", DataType::Utf8, false), Field::new("l_comment", DataType::Utf8, false), diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index d9a00c8e290..d36579a1979 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -48,9 +48,9 @@ use arrow::datatypes::{DataType, DateUnit, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::{ array::{ - ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, StringArray, + TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, datatypes::Field, }; @@ -1135,6 +1135,9 @@ macro_rules! binary_array_op { DataType::Date32(DateUnit::Day) => { compute_op!($LEFT, $RIGHT, $OP, Date32Array) } + DataType::Date64(DateUnit::Millisecond) => { + compute_op!($LEFT, $RIGHT, $OP, Date64Array) + } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?}", other @@ -1227,6 +1230,19 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } +/// Coercion rules for Temporal columns: the type that both lhs and rhs can be +/// casted to for the purpose of a date computation +fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Utf8, Date32(DateUnit::Day)) => Some(Date32(DateUnit::Day)), + (Date32(DateUnit::Day), Utf8) => Some(Date32(DateUnit::Day)), + (Utf8, Date64(DateUnit::Millisecond)) => Some(Date64(DateUnit::Millisecond)), + (Date64(DateUnit::Millisecond), Utf8) => Some(Date64(DateUnit::Millisecond)), + _ => None, + } +} + /// Coercion rule for numerical types: The type that both lhs and rhs /// can be casted to for numerical calculation, while maintaining /// maximum precision @@ -1288,6 +1304,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { } numerical_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| temporal_coercion(lhs_type, rhs_type)) } // coercion rules that assume an ordered set, such as "less than". @@ -1301,6 +1318,7 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option numerical_coercion(lhs_type, rhs_type) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| temporal_coercion(lhs_type, rhs_type)) } /// Coercion rules for all binary operators. Returns the output type @@ -2629,6 +2647,54 @@ mod tests { DataType::Boolean, vec![true, false] ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["1994-12-13", "1995-01-26"], + Date32Array, + DataType::Date32(DateUnit::Day), + vec![9112, 9156], + Operator::Eq, + BooleanArray, + DataType::Boolean, + vec![true, true] + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["1994-12-13", "1995-01-26"], + Date32Array, + DataType::Date32(DateUnit::Day), + vec![9113, 9154], + Operator::Lt, + BooleanArray, + DataType::Boolean, + vec![true, false] + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["1994-12-13", "1995-01-26"], + Date64Array, + DataType::Date64(DateUnit::Millisecond), + vec![787276800000, 791078400000], + Operator::Eq, + BooleanArray, + DataType::Boolean, + vec![true, true] + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["1994-12-13", "1995-01-26"], + Date64Array, + DataType::Date64(DateUnit::Millisecond), + vec![787276800001, 791078399999], + Operator::Lt, + BooleanArray, + DataType::Boolean, + vec![true, false] + ); Ok(()) } diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index a807ed08f65..878551c7046 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -629,6 +629,14 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { data_type: convert_data_type(data_type)?, }), + SQLExpr::TypedString { + ref data_type, + ref value, + } => Ok(Expr::Cast { + expr: Box::new(lit(&**value)), + data_type: convert_data_type(data_type)?, + }), + SQLExpr::IsNull(ref expr) => { Ok(Expr::IsNull(Box::new(self.sql_to_rex(expr, schema)?))) } @@ -1311,6 +1319,14 @@ mod tests { quick_test(sql, expected); } + #[test] + fn select_typedstring() { + let sql = "SELECT date '2020-12-10' AS date FROM person"; + let expected = "Projection: CAST(Utf8(\"2020-12-10\") AS Date32(Day)) AS date\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + fn logical_plan(sql: &str) -> Result { let planner = SqlToRel::new(&MockSchemaProvider {}); let result = DFParser::parse_sql(&sql);