-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Refactor range/gen_series signature away from user defined
#18317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -311,10 +311,7 @@ impl TypeSignatureClass { | |
| } | ||
|
|
||
| /// Does the specified `NativeType` match this type signature class? | ||
| pub fn matches_native_type( | ||
| self: &TypeSignatureClass, | ||
| logical_type: &NativeType, | ||
| ) -> bool { | ||
| pub fn matches_native_type(&self, logical_type: &NativeType) -> bool { | ||
| if logical_type == &NativeType::Null { | ||
| return true; | ||
| } | ||
|
|
@@ -360,6 +357,7 @@ impl TypeSignatureClass { | |
| TypeSignatureClass::Binary if native_type.is_binary() => { | ||
| Ok(origin_type.to_owned()) | ||
| } | ||
| _ if native_type.is_null() => Ok(origin_type.to_owned()), | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We were missing this even though we check for it in |
||
| _ => internal_err!("May miss the matching logic in `matches_native_type`"), | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,33 +18,39 @@ | |
| //! [`ScalarUDFImpl`] definitions for range and gen_series functions. | ||
|
|
||
| use crate::utils::make_scalar_function; | ||
| use arrow::array::{ | ||
| builder::{Date32Builder, TimestampNanosecondBuilder}, | ||
| temporal_conversions::as_datetime_with_timezone, | ||
| timezone::Tz, | ||
| types::{Date32Type, IntervalMonthDayNanoType, TimestampNanosecondType}, | ||
| Array, ArrayRef, Int64Array, ListArray, ListBuilder, NullBufferBuilder, | ||
| }; | ||
| use arrow::buffer::OffsetBuffer; | ||
| use arrow::datatypes::{ | ||
| DataType, DataType::*, Field, IntervalUnit::MonthDayNano, TimeUnit::Nanosecond, | ||
| use arrow::datatypes::TimeUnit; | ||
| use arrow::datatypes::{DataType, Field, IntervalUnit::MonthDayNano}; | ||
| use arrow::{ | ||
| array::{ | ||
| builder::{Date32Builder, TimestampNanosecondBuilder}, | ||
| temporal_conversions::as_datetime_with_timezone, | ||
| timezone::Tz, | ||
| types::{Date32Type, IntervalMonthDayNanoType, TimestampNanosecondType}, | ||
| Array, ArrayRef, Int64Array, ListArray, ListBuilder, NullBufferBuilder, | ||
| }, | ||
| compute::cast, | ||
| }; | ||
| use datafusion_common::internal_err; | ||
| use datafusion_common::{ | ||
| cast::{ | ||
| as_date32_array, as_int64_array, as_interval_mdn_array, | ||
| as_timestamp_nanosecond_array, | ||
| }, | ||
| DataFusionError, ScalarValue, | ||
| types::{ | ||
| logical_date, logical_int64, logical_interval_mdn, logical_string, NativeType, | ||
| }, | ||
| ScalarValue, | ||
| }; | ||
| use datafusion_common::{ | ||
| exec_datafusion_err, exec_err, not_impl_datafusion_err, utils::take_function_args, | ||
| Result, | ||
| }; | ||
| use datafusion_expr::{ | ||
| ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, | ||
| Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, | ||
| TypeSignatureClass, Volatility, | ||
| }; | ||
| use datafusion_macros::user_doc; | ||
| use itertools::Itertools; | ||
| use std::any::Any; | ||
| use std::cmp::Ordering; | ||
| use std::iter::from_fn; | ||
|
|
@@ -146,18 +152,60 @@ impl Default for Range { | |
| } | ||
|
|
||
| impl Range { | ||
| fn defined_signature() -> Signature { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the main change |
||
| // We natively only support i64 in our implementation; so ensure we cast other integer | ||
| // types to it. | ||
| let integer = Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_int64()), | ||
| vec![TypeSignatureClass::Integer], | ||
| NativeType::Int64, | ||
| ); | ||
| // We natively only support mdn in our implementation; so ensure we cast other interval | ||
| // types to it. | ||
| let interval = Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_interval_mdn()), | ||
| vec![TypeSignatureClass::Interval], | ||
| NativeType::Interval(MonthDayNano), | ||
| ); | ||
| // Ideally we'd limit to only Date32 & Timestamp(Nanoseconds) as those are the implementations | ||
| // we have but that is difficult to do with this current API; we'll cast later on to | ||
| // handle such types. | ||
| let date = Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_date()), | ||
| vec![TypeSignatureClass::Native(logical_string())], | ||
| NativeType::Date, | ||
| ); | ||
| let timestamp = Coercion::new_exact(TypeSignatureClass::Timestamp); | ||
| Signature::one_of( | ||
| vec![ | ||
| // Integer ranges | ||
| // Stop | ||
| TypeSignature::Coercible(vec![integer.clone()]), | ||
| // Start & stop | ||
| TypeSignature::Coercible(vec![integer.clone(), integer.clone()]), | ||
| // Start, stop & step | ||
| TypeSignature::Coercible(vec![integer.clone(), integer.clone(), integer]), | ||
| // Date range | ||
| TypeSignature::Coercible(vec![date.clone(), date, interval.clone()]), | ||
| // Timestamp range | ||
| TypeSignature::Coercible(vec![timestamp.clone(), timestamp, interval]), | ||
| ], | ||
| Volatility::Immutable, | ||
| ) | ||
| } | ||
|
|
||
| /// Generate `range()` function which excludes upper bound. | ||
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::user_defined(Volatility::Immutable), | ||
| signature: Self::defined_signature(), | ||
| include_upper_bound: false, | ||
| } | ||
| } | ||
|
|
||
| /// Generate `generate_series()` function which includes upper bound. | ||
| fn generate_series() -> Self { | ||
| Self { | ||
| signature: Signature::user_defined(Volatility::Immutable), | ||
| signature: Self::defined_signature(), | ||
| include_upper_bound: true, | ||
| } | ||
| } | ||
|
|
@@ -180,39 +228,27 @@ impl ScalarUDFImpl for Range { | |
| &self.signature | ||
| } | ||
|
|
||
| fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||
| arg_types | ||
| .iter() | ||
| .map(|arg_type| match arg_type { | ||
| Null => Ok(Null), | ||
| Int8 => Ok(Int64), | ||
| Int16 => Ok(Int64), | ||
| Int32 => Ok(Int64), | ||
| Int64 => Ok(Int64), | ||
| UInt8 => Ok(Int64), | ||
| UInt16 => Ok(Int64), | ||
| UInt32 => Ok(Int64), | ||
| UInt64 => Ok(Int64), | ||
| Timestamp(_, tz) => Ok(Timestamp(Nanosecond, tz.clone())), | ||
| Date32 => Ok(Date32), | ||
| Date64 => Ok(Date32), | ||
| Utf8 => Ok(Date32), | ||
| LargeUtf8 => Ok(Date32), | ||
| Utf8View => Ok(Date32), | ||
| Interval(_) => Ok(Interval(MonthDayNano)), | ||
| _ => exec_err!("Unsupported DataType"), | ||
| }) | ||
| .try_collect() | ||
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| if arg_types.iter().any(|t| t.is_null()) { | ||
| Ok(Null) | ||
| } else { | ||
| Ok(List(Arc::new(Field::new_list_field( | ||
| return Ok(DataType::Null); | ||
| } | ||
|
|
||
| match (&arg_types[0], arg_types.get(1)) { | ||
| // In implementation we downcast to Date32 so ensure reflect that here | ||
| (_, Some(DataType::Date64)) | (DataType::Date64, _) => Ok(DataType::List( | ||
| Arc::new(Field::new_list_field(DataType::Date32, true)), | ||
| )), | ||
| // Ensure we preserve timezone | ||
| (DataType::Timestamp(_, tz), _) => { | ||
| Ok(DataType::List(Arc::new(Field::new_list_field( | ||
| DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()), | ||
| true, | ||
| )))) | ||
| } | ||
| _ => Ok(DataType::List(Arc::new(Field::new_list_field( | ||
| arg_types[0].clone(), | ||
| true, | ||
| )))) | ||
| )))), | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -226,13 +262,20 @@ impl ScalarUDFImpl for Range { | |
| return Ok(ColumnarValue::Scalar(ScalarValue::Null)); | ||
| } | ||
| match args[0].data_type() { | ||
| Int64 => make_scalar_function(|args| self.gen_range_inner(args))(args), | ||
| Date32 => make_scalar_function(|args| self.gen_range_date(args))(args), | ||
| Timestamp(_, _) => { | ||
| DataType::Int64 => { | ||
| make_scalar_function(|args| self.gen_range_inner(args))(args) | ||
| } | ||
| DataType::Date32 | DataType::Date64 => { | ||
| make_scalar_function(|args| self.gen_range_date(args))(args) | ||
| } | ||
| DataType::Timestamp(_, _) => { | ||
| make_scalar_function(|args| self.gen_range_timestamp(args))(args) | ||
| } | ||
| dt => { | ||
| exec_err!("unsupported type for {}. Expected Int64, Date32 or Timestamp, got: {dt}", self.name()) | ||
| internal_err!( | ||
| "Signature failed to guard unknown input type for {}: {dt}", | ||
| self.name() | ||
| ) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -274,7 +317,7 @@ impl Range { | |
| as_int64_array(stop_array)?, | ||
| Some(as_int64_array(step_array)?), | ||
| ), | ||
| _ => return exec_err!("{} expects 1 to 3 arguments", self.name()), | ||
| _ => return internal_err!("{} expects 1 to 3 arguments", self.name()), | ||
| }; | ||
|
|
||
| let mut values = vec![]; | ||
|
|
@@ -310,7 +353,7 @@ impl Range { | |
| }; | ||
| } | ||
| let arr = Arc::new(ListArray::try_new( | ||
| Arc::new(Field::new_list_field(Int64, true)), | ||
| Arc::new(Field::new_list_field(DataType::Int64, true)), | ||
| OffsetBuffer::new(offsets.into()), | ||
| Arc::new(Int64Array::from(values)), | ||
| valid.finish(), | ||
|
|
@@ -320,29 +363,28 @@ impl Range { | |
|
|
||
| fn gen_range_date(&self, args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [start, stop, step] = take_function_args(self.name(), args)?; | ||
| let step = as_interval_mdn_array(step)?; | ||
|
|
||
| let (start_array, stop_array, step_array) = ( | ||
| as_date32_array(start)?, | ||
| as_date32_array(stop)?, | ||
| as_interval_mdn_array(step)?, | ||
| ); | ||
| // Signature can only guarantee we get a date type, not specifically | ||
| // date32 so handle potential cast from date64 here. | ||
| let start = cast(start, &DataType::Date32)?; | ||
| let start = as_date32_array(&start)?; | ||
| let stop = cast(stop, &DataType::Date32)?; | ||
| let stop = as_date32_array(&stop)?; | ||
|
Comment on lines
+368
to
+373
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see much benefit in handling date64 natively anyway (given it's a bit of a weird type) so I think this is fine |
||
|
|
||
| // values are date32s | ||
| let values_builder = Date32Builder::new(); | ||
| let mut list_builder = ListBuilder::new(values_builder); | ||
|
|
||
| for idx in 0..stop_array.len() { | ||
| if start_array.is_null(idx) | ||
| || stop_array.is_null(idx) | ||
| || step_array.is_null(idx) | ||
| { | ||
| for idx in 0..stop.len() { | ||
| if start.is_null(idx) || stop.is_null(idx) || step.is_null(idx) { | ||
| list_builder.append_null(); | ||
| continue; | ||
| } | ||
|
|
||
| let start = start_array.value(idx); | ||
| let stop = stop_array.value(idx); | ||
| let step = step_array.value(idx); | ||
| let start = start.value(idx); | ||
| let stop = stop.value(idx); | ||
| let step = step.value(idx); | ||
|
|
||
| let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); | ||
| if months == 0 && days == 0 { | ||
|
|
@@ -378,44 +420,45 @@ impl Range { | |
|
|
||
| fn gen_range_timestamp(&self, args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [start, stop, step] = take_function_args(self.name(), args)?; | ||
| let step = as_interval_mdn_array(step)?; | ||
|
|
||
| // Signature can only guarantee we get a timestamp type, not specifically | ||
| // timestamp(ns) so handle potential cast from other timestamps here. | ||
| fn cast_to_ns(arr: &ArrayRef) -> Result<ArrayRef> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. given
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah it's to cast from potentially other timestamp time units down to nanoseconds |
||
| match arr.data_type() { | ||
| DataType::Timestamp(TimeUnit::Nanosecond, _) => Ok(Arc::clone(arr)), | ||
| DataType::Timestamp(_, tz) => Ok(cast( | ||
| arr, | ||
| &DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()), | ||
| )?), | ||
| _ => unreachable!(), | ||
| } | ||
| } | ||
| let start = cast_to_ns(start)?; | ||
| let start = as_timestamp_nanosecond_array(&start)?; | ||
| let stop = cast_to_ns(stop)?; | ||
| let stop = as_timestamp_nanosecond_array(&stop)?; | ||
|
Comment on lines
+425
to
+440
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if is better to instead try uplift the signature coercible API to allow specifying we only want timestamp(ns) (of any timezone) and let that handle coercion for us 🤔
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that would be nicer but I think it could also be done as a follow on PR |
||
|
|
||
| // coerce_types fn should coerce all types to Timestamp(Nanosecond, tz) | ||
| // TODO: remove these map_err once the signature is robust enough to guard against this | ||
| let start_arr = as_timestamp_nanosecond_array(start).map_err(|_e| { | ||
| DataFusionError::Internal(format!( | ||
| "Unexpected argument type for {} : {}", | ||
| self.name(), | ||
| start.data_type() | ||
| )) | ||
| })?; | ||
| let stop_arr = as_timestamp_nanosecond_array(stop).map_err(|_e| { | ||
| DataFusionError::Internal(format!( | ||
| "Unexpected argument type for {} : {}", | ||
| self.name(), | ||
| stop.data_type() | ||
| )) | ||
| })?; | ||
| let step_arr = as_interval_mdn_array(step)?; | ||
| let start_tz = parse_tz(&start_arr.timezone())?; | ||
| let stop_tz = parse_tz(&stop_arr.timezone())?; | ||
| let start_tz = parse_tz(&start.timezone())?; | ||
| let stop_tz = parse_tz(&stop.timezone())?; | ||
|
|
||
| // values are timestamps | ||
| let values_builder = start_arr | ||
| let values_builder = start | ||
| .timezone() | ||
| .map_or_else(TimestampNanosecondBuilder::new, |start_tz_str| { | ||
| TimestampNanosecondBuilder::new().with_timezone(start_tz_str) | ||
| }); | ||
| let mut list_builder = ListBuilder::new(values_builder); | ||
|
|
||
| for idx in 0..start_arr.len() { | ||
| if start_arr.is_null(idx) || stop_arr.is_null(idx) || step_arr.is_null(idx) { | ||
| for idx in 0..start.len() { | ||
| if start.is_null(idx) || stop.is_null(idx) || step.is_null(idx) { | ||
| list_builder.append_null(); | ||
| continue; | ||
| } | ||
|
|
||
| let start = start_arr.value(idx); | ||
| let stop = stop_arr.value(idx); | ||
| let step = step_arr.value(idx); | ||
| let start = start.value(idx); | ||
| let stop = stop.value(idx); | ||
| let step = step.value(idx); | ||
|
|
||
| let (months, days, ns) = IntervalMonthDayNanoType::to_parts(step); | ||
| if months == 0 && days == 0 && ns == 0 { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you potentially add some documentation here about what is different about this macro compared to the others in this module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done 👍