Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions rust/arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,

(Utf8, Date32(DateUnit::Day)) => true,
(Utf8, Date64(DateUnit::Millisecond)) => true,
(Utf8, _) => DataType::is_numeric(to_type),
(_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary,

Expand Down Expand Up @@ -399,6 +400,26 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
Date64(DateUnit::Millisecond) => {
use chrono::{NaiveDate, NaiveTime};
let zero_time = NaiveTime::from_hms(0, 0, 0);
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = PrimitiveBuilder::<Date64Type>::new(string_array.len());
for i in 0..string_array.len() {
if string_array.is_null(i) {
builder.append_null()?;
} else {
match NaiveDate::parse_from_str(string_array.value(i), "%Y-%m-%d")
{
Ok(date) => builder.append_value(
date.and_time(zero_time).timestamp_millis() as i64,
)?,
Err(_) => builder.append_null()?, // not a valid date
};
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
Expand Down Expand Up @@ -2780,6 +2801,31 @@ mod tests {
assert_eq!(false, c.is_valid(4)); // "2000"
}

#[test]
fn test_cast_utf8_to_date64() {
    // One entry per parsing scenario; the indices are referenced by the
    // assertions below.
    let inputs = vec![
        "2000-01-01",          // 0: valid date with leading 0s
        "2000-2-2",            // 1: valid date without leading 0s
        "2000-00-00",          // 2: invalid month and day
        "2000-01-01T12:00:00", // 3: date + time is invalid
        "2000",                // 4: just a year is invalid
    ];
    let source: ArrayRef = Arc::new(StringArray::from(inputs));
    let casted = cast(&source, &DataType::Date64(DateUnit::Millisecond)).unwrap();
    let dates = casted.as_any().downcast_ref::<Date64Array>().unwrap();

    // Valid inputs are converted to milliseconds at midnight of that day.
    assert!(dates.is_valid(0)); // "2000-01-01"
    assert_eq!(946684800000, dates.value(0));
    assert!(dates.is_valid(1)); // "2000-2-2"
    assert_eq!(949449600000, dates.value(1));

    // Unparseable inputs come back as nulls instead of failing the cast.
    assert!(!dates.is_valid(2)); // "2000-00-00"
    assert!(!dates.is_valid(3)); // "2000-01-01T12:00:00"
    assert!(!dates.is_valid(4)); // "2000"
}

#[test]
fn test_can_cast_types() {
// this function attempts to ensure that can_cast_types stays
Expand Down
75 changes: 54 additions & 21 deletions rust/benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, DateUnit, Field, Schema};
use arrow::util::pretty;
use datafusion::datasource::parquet::ParquetTable;
use datafusion::datasource::{CsvFile, MemTable, TableProvider};
Expand Down Expand Up @@ -187,7 +187,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
from
lineitem
where
l_shipdate <= '1998-09-02'
l_shipdate <= date '1998-09-02'
group by
l_returnflag,
l_linestatus
Expand Down Expand Up @@ -256,8 +256,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
c_mktsegment = 'BUILDING'
and c_custkey = o_custkey
and l_orderkey = o_orderkey
and o_orderdate < '1995-03-15'
and l_shipdate > '1995-03-15'
and o_orderdate < date '1995-03-15'
and l_shipdate > date '1995-03-15'
group by
l_orderkey,
o_orderdate,
Expand Down Expand Up @@ -337,8 +337,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
and s_nationkey = n_nationkey
and n_regionkey = r_regionkey
and r_name = 'ASIA'
and o_orderdate >= '1994-01-01'
and o_orderdate < '1995-01-01'
and o_orderdate >= date '1994-01-01'
and o_orderdate < date '1995-01-01'
group by
n_name
order by
Expand All @@ -363,9 +363,9 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
from
lineitem
where
l_shipdate >= '1994-01-01'
and l_shipdate < '1995-01-01'
and l_discount between 0.06 - 0.01 and 0.06 + 0.01
l_shipdate >= date '1994-01-01'
and l_shipdate < date '1995-01-01'
and l_discount > 0.06 - 0.01 and l_discount < 0.06 + 0.01
Copy link
Contributor

@Dandandan Dandandan Dec 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this use the `between` support that was also added recently? @seddonm1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry, I will fix this.

Copy link
Contributor Author

@seddonm1 seddonm1 Dec 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, no. The raw query as per TPC-H does not use between for that clause:

where l_shipdate >= date '[DATE]' 
and l_shipdate < date '[DATE]' + interval '1' year
and l_discount between [DISCOUNT] - 0.01 and [DISCOUNT] + 0.01

Which is different as BETWEEN is inclusive (>= AND <=)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see a between on the last line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, sorry, I thought you meant adding it to the l_shipdate component. Yes, I will fix that.

and l_quantity < 24;"
),

Expand Down Expand Up @@ -399,7 +399,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
)
and l_shipdate > '1995-01-01' and l_shipdate < '1996-12-31'
and l_shipdate > date '1995-01-01' and l_shipdate < date '1996-12-31'
) as shipping
group by
supp_nation,
Expand Down Expand Up @@ -442,7 +442,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
and n1.n_regionkey = r_regionkey
and r_name = 'AMERICA'
and s_nationkey = n2.n_nationkey
and o_orderdate between '1995-01-01' and '1996-12-31'
and o_orderdate between date '1995-01-01' and date '1996-12-31'
and p_type = 'ECONOMY ANODIZED STEEL'
) as all_nations
group by
Expand Down Expand Up @@ -486,6 +486,39 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
o_year desc;"
),

// 10 => ctx.create_logical_plan(
// "select
// c_custkey,
// c_name,
// sum(l_extendedprice * (1 - l_discount)) as revenue,
// c_acctbal,
// n_name,
// c_address,
// c_phone,
// c_comment
// from
// customer,
// orders,
// lineitem,
// nation
// where
// c_custkey = o_custkey
// and l_orderkey = o_orderkey
// and o_orderdate >= date '1993-10-01'
// and o_orderdate < date '1993-10-01' + interval '3' month
// and l_returnflag = 'R'
// and c_nationkey = n_nationkey
// group by
// c_custkey,
// c_name,
// c_acctbal,
// c_phone,
// n_name,
// c_address,
// c_comment
// order by
// revenue desc;"
// ),
10 => ctx.create_logical_plan(
"select
c_custkey,
Expand All @@ -504,8 +537,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
where
c_custkey = o_custkey
and l_orderkey = o_orderkey
and o_orderdate >= '1993-10-01'
and o_orderdate < '1994-01-01'
and o_orderdate >= date '1993-10-01'
and o_orderdate < date '1994-01-01'
and l_returnflag = 'R'
and c_nationkey = n_nationkey
group by
Expand Down Expand Up @@ -606,8 +639,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
(l_shipmode = 'MAIL' or l_shipmode = 'SHIP')
and l_commitdate < l_receiptdate
and l_shipdate < l_commitdate
and l_receiptdate >= '1994-01-01'
and l_receiptdate < '1995-01-01'
and l_receiptdate >= date '1994-01-01'
and l_receiptdate < date '1995-01-01'
group by
l_shipmode
order by
Expand Down Expand Up @@ -649,8 +682,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
part
where
l_partkey = p_partkey
and l_shipdate >= '1995-09-01'
and l_shipdate < '1995-10-01';"
and l_shipdate >= date '1995-09-01'
and l_shipdate < date '1995-10-01';"
),

15 => ctx.create_logical_plan(
Expand Down Expand Up @@ -1072,7 +1105,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("o_custkey", DataType::UInt32, false),
Field::new("o_orderstatus", DataType::Utf8, false),
Field::new("o_totalprice", DataType::Float64, false), // decimal
Field::new("o_orderdate", DataType::Utf8, false),
Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
Field::new("o_shippriority", DataType::UInt32, false),
Expand All @@ -1090,9 +1123,9 @@ fn get_schema(table: &str) -> Schema {
Field::new("l_tax", DataType::Float64, false), // decimal
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Utf8, false),
Field::new("l_commitdate", DataType::Utf8, false),
Field::new("l_receiptdate", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false),
Field::new("l_commitdate", DataType::Date32(DateUnit::Day), false),
Field::new("l_receiptdate", DataType::Date32(DateUnit::Day), false),
Field::new("l_shipinstruct", DataType::Utf8, false),
Field::new("l_shipmode", DataType::Utf8, false),
Field::new("l_comment", DataType::Utf8, false),
Expand Down
72 changes: 69 additions & 3 deletions rust/datafusion/src/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ use arrow::datatypes::{DataType, DateUnit, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{
ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
datatypes::Field,
};
Expand Down Expand Up @@ -1135,6 +1135,9 @@ macro_rules! binary_array_op {
DataType::Date32(DateUnit::Day) => {
compute_op!($LEFT, $RIGHT, $OP, Date32Array)
}
DataType::Date64(DateUnit::Millisecond) => {
compute_op!($LEFT, $RIGHT, $OP, Date64Array)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?}",
other
Expand Down Expand Up @@ -1227,6 +1230,19 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
}
}

/// Coercion rules for temporal columns: the common type that both `lhs_type`
/// and `rhs_type` can be cast to so that a date comparison can be evaluated.
///
/// A `Utf8` operand paired with a `Date32(Day)` or `Date64(Millisecond)`
/// operand, in either order, coerces to that date type. Every other
/// combination yields `None`.
fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;
    let coerced = match (lhs_type, rhs_type) {
        // string <-> 32-bit day-resolution date
        (Utf8, Date32(DateUnit::Day)) | (Date32(DateUnit::Day), Utf8) => {
            Date32(DateUnit::Day)
        }
        // string <-> 64-bit millisecond-resolution date
        (Utf8, Date64(DateUnit::Millisecond)) | (Date64(DateUnit::Millisecond), Utf8) => {
            Date64(DateUnit::Millisecond)
        }
        _ => return None,
    };
    Some(coerced)
}

/// Coercion rule for numerical types: the type that both lhs and rhs
/// can be cast to for numerical calculation, while maintaining
/// maximum precision
Expand Down Expand Up @@ -1288,6 +1304,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
}
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
}

// coercion rules that assume an ordered set, such as "less than".
Expand All @@ -1301,6 +1318,7 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
numerical_coercion(lhs_type, rhs_type)
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
}

/// Coercion rules for all binary operators. Returns the output type
Expand Down Expand Up @@ -2629,6 +2647,54 @@ mod tests {
DataType::Boolean,
vec![true, false]
);
test_coercion!(
StringArray,
DataType::Utf8,
vec!["1994-12-13", "1995-01-26"],
Date32Array,
DataType::Date32(DateUnit::Day),
vec![9112, 9156],
Operator::Eq,
BooleanArray,
DataType::Boolean,
vec![true, true]
);
test_coercion!(
StringArray,
DataType::Utf8,
vec!["1994-12-13", "1995-01-26"],
Date32Array,
DataType::Date32(DateUnit::Day),
vec![9113, 9154],
Operator::Lt,
BooleanArray,
DataType::Boolean,
vec![true, false]
);
test_coercion!(
StringArray,
DataType::Utf8,
vec!["1994-12-13", "1995-01-26"],
Date64Array,
DataType::Date64(DateUnit::Millisecond),
vec![787276800000, 791078400000],
Operator::Eq,
BooleanArray,
DataType::Boolean,
vec![true, true]
);
test_coercion!(
StringArray,
DataType::Utf8,
vec!["1994-12-13", "1995-01-26"],
Date64Array,
DataType::Date64(DateUnit::Millisecond),
vec![787276800001, 791078399999],
Operator::Lt,
BooleanArray,
DataType::Boolean,
vec![true, false]
);
Ok(())
}

Expand Down
16 changes: 16 additions & 0 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,14 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
data_type: convert_data_type(data_type)?,
}),

SQLExpr::TypedString {
ref data_type,
ref value,
} => Ok(Expr::Cast {
expr: Box::new(lit(&**value)),
data_type: convert_data_type(data_type)?,
}),

SQLExpr::IsNull(ref expr) => {
Ok(Expr::IsNull(Box::new(self.sql_to_rex(expr, schema)?)))
}
Expand Down Expand Up @@ -1311,6 +1319,14 @@ mod tests {
quick_test(sql, expected);
}

// Checks that a typed string literal (`date '...'`) is planned as a CAST of
// the Utf8 literal to the named type.
#[test]
fn select_typedstring() {
let sql = "SELECT date '2020-12-10' AS date FROM person";
// Expected plan text; the trailing backslash joins the two source lines of
// this literal.
let expected = "Projection: CAST(Utf8(\"2020-12-10\") AS Date32(Day)) AS date\
\n TableScan: person projection=None";
quick_test(sql, expected);
}

fn logical_plan(sql: &str) -> Result<LogicalPlan> {
let planner = SqlToRel::new(&MockSchemaProvider {});
let result = DFParser::parse_sql(&sql);
Expand Down