-
Notifications
You must be signed in to change notification settings - Fork 1.9k
WIP Optimize hash_aggregate when there are no null group keys #922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -424,6 +424,25 @@ macro_rules! build_array_from_option { | |
| }}; | ||
| } | ||
|
|
||
| macro_rules! eq_array_general { | ||
| ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr, $has_nulls:expr) => {{ | ||
| let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); | ||
| if *$has_nulls { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This branch should ideally be outside of the loop. But this might be hard to accomplish given the current design. I posted some ideas in my earlier comment on this PR |
||
| let is_valid = array.is_valid($index); | ||
| match $VALUE { | ||
| Some(val) => is_valid && &array.value($index) == val, | ||
| None => !is_valid, | ||
| } | ||
| } else { | ||
| match $VALUE { | ||
| Some(val) => &array.value($index) == val, | ||
| None => false, | ||
| } | ||
| } | ||
| }}; | ||
| } | ||
|
|
||
| /* | ||
| macro_rules! eq_array_primitive { | ||
| ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ | ||
| let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); | ||
|
|
@@ -435,6 +454,17 @@ macro_rules! eq_array_primitive { | |
| }}; | ||
| } | ||
|
|
||
| macro_rules! eq_array_no_nulls_primitive { | ||
| ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ | ||
| let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); | ||
| match $VALUE { | ||
| Some(val) => &array.value($index) == val, | ||
| None => false, | ||
| } | ||
| }}; | ||
| } | ||
| */ | ||
|
|
||
| impl ScalarValue { | ||
| /// Getter for the `DataType` of the value | ||
| pub fn get_datatype(&self) -> DataType { | ||
|
|
@@ -1028,69 +1058,81 @@ impl ScalarValue { | |
| /// This function has a few narrow usescases such as hash table key | ||
| /// comparisons where comparing a single row at a time is necessary. | ||
| #[inline] | ||
| pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { | ||
| pub fn eq_array(&self, array: &ArrayRef, index: usize, has_nulls: &bool) -> bool { | ||
| if let DataType::Dictionary(key_type, _) = array.data_type() { | ||
| return self.eq_array_dictionary(array, index, key_type); | ||
| return self.eq_array_dictionary(array, index, key_type, has_nulls); | ||
| } | ||
|
|
||
| match self { | ||
| ScalarValue::Boolean(val) => { | ||
| eq_array_primitive!(array, index, BooleanArray, val) | ||
| eq_array_general!(array, index, BooleanArray, val, has_nulls) | ||
| } | ||
| ScalarValue::Float32(val) => { | ||
| eq_array_primitive!(array, index, Float32Array, val) | ||
| eq_array_general!(array, index, Float32Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Float64(val) => { | ||
| eq_array_primitive!(array, index, Float64Array, val) | ||
| eq_array_general!(array, index, Float64Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Int8(val) => { | ||
| eq_array_general!(array, index, Int8Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Int16(val) => { | ||
| eq_array_general!(array, index, Int16Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Int32(val) => { | ||
| eq_array_general!(array, index, Int32Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Int64(val) => { | ||
| eq_array_general!(array, index, Int64Array, val, has_nulls) | ||
| } | ||
| ScalarValue::UInt8(val) => { | ||
| eq_array_general!(array, index, UInt8Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), | ||
| ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), | ||
| ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), | ||
| ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), | ||
| ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), | ||
| ScalarValue::UInt16(val) => { | ||
| eq_array_primitive!(array, index, UInt16Array, val) | ||
| eq_array_general!(array, index, UInt16Array, val, has_nulls) | ||
| } | ||
| ScalarValue::UInt32(val) => { | ||
| eq_array_primitive!(array, index, UInt32Array, val) | ||
| eq_array_general!(array, index, UInt32Array, val, has_nulls) | ||
| } | ||
| ScalarValue::UInt64(val) => { | ||
| eq_array_primitive!(array, index, UInt64Array, val) | ||
| eq_array_general!(array, index, UInt64Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Utf8(val) => { | ||
| eq_array_general!(array, index, StringArray, val, has_nulls) | ||
| } | ||
| ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), | ||
| ScalarValue::LargeUtf8(val) => { | ||
| eq_array_primitive!(array, index, LargeStringArray, val) | ||
| eq_array_general!(array, index, LargeStringArray, val, has_nulls) | ||
| } | ||
| ScalarValue::Binary(val) => { | ||
| eq_array_primitive!(array, index, BinaryArray, val) | ||
| eq_array_general!(array, index, BinaryArray, val, has_nulls) | ||
| } | ||
| ScalarValue::LargeBinary(val) => { | ||
| eq_array_primitive!(array, index, LargeBinaryArray, val) | ||
| eq_array_general!(array, index, LargeBinaryArray, val, has_nulls) | ||
| } | ||
| ScalarValue::List(_, _) => unimplemented!(), | ||
| ScalarValue::Date32(val) => { | ||
| eq_array_primitive!(array, index, Date32Array, val) | ||
| eq_array_general!(array, index, Date32Array, val, has_nulls) | ||
| } | ||
| ScalarValue::Date64(val) => { | ||
| eq_array_primitive!(array, index, Date64Array, val) | ||
| eq_array_general!(array, index, Date64Array, val, has_nulls) | ||
| } | ||
| ScalarValue::TimestampSecond(val) => { | ||
| eq_array_primitive!(array, index, TimestampSecondArray, val) | ||
| eq_array_general!(array, index, TimestampSecondArray, val, has_nulls) | ||
| } | ||
| ScalarValue::TimestampMillisecond(val) => { | ||
| eq_array_primitive!(array, index, TimestampMillisecondArray, val) | ||
| eq_array_general!(array, index, TimestampMillisecondArray, val, has_nulls) | ||
| } | ||
| ScalarValue::TimestampMicrosecond(val) => { | ||
| eq_array_primitive!(array, index, TimestampMicrosecondArray, val) | ||
| eq_array_general!(array, index, TimestampMicrosecondArray, val, has_nulls) | ||
| } | ||
| ScalarValue::TimestampNanosecond(val) => { | ||
| eq_array_primitive!(array, index, TimestampNanosecondArray, val) | ||
| eq_array_general!(array, index, TimestampNanosecondArray, val, has_nulls) | ||
| } | ||
| ScalarValue::IntervalYearMonth(val) => { | ||
| eq_array_primitive!(array, index, IntervalYearMonthArray, val) | ||
| eq_array_general!(array, index, IntervalYearMonthArray, val, has_nulls) | ||
| } | ||
| ScalarValue::IntervalDayTime(val) => { | ||
| eq_array_primitive!(array, index, IntervalDayTimeArray, val) | ||
| eq_array_general!(array, index, IntervalDayTimeArray, val, has_nulls) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1102,6 +1144,7 @@ impl ScalarValue { | |
| array: &ArrayRef, | ||
| index: usize, | ||
| key_type: &DataType, | ||
| has_nulls: &bool, | ||
| ) -> bool { | ||
| let (values, values_index) = match key_type { | ||
| DataType::Int8 => get_dict_value::<Int8Type>(array, index).unwrap(), | ||
|
|
@@ -1116,7 +1159,7 @@ impl ScalarValue { | |
| }; | ||
|
|
||
| match values_index { | ||
| Some(values_index) => self.eq_array(values, values_index), | ||
| Some(values_index) => self.eq_array(values, values_index, has_nulls), | ||
| None => self.is_null(), | ||
| } | ||
| } | ||
|
|
@@ -1914,7 +1957,7 @@ mod tests { | |
|
|
||
| for (index, scalar) in scalars.into_iter().enumerate() { | ||
| assert!( | ||
| scalar.eq_array(&array, index), | ||
| scalar.eq_array(&array, index, &true), | ||
| "Expected {:?} to be equal to {:?} at index {}", | ||
| scalar, | ||
| array, | ||
|
|
@@ -1925,7 +1968,7 @@ mod tests { | |
| for other_index in 0..array.len() { | ||
| if index != other_index { | ||
| assert!( | ||
| !scalar.eq_array(&array, other_index), | ||
| !scalar.eq_array(&array, other_index, &true), | ||
| "Expected {:?} to be NOT equal to {:?} at index {}", | ||
| scalar, | ||
| array, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think broadly speaking this is the pattern, though I am not sure how well the compiler will be able to optimize given a has nulls parameter `has_nulls` -- it might need to be hoisted out of this loop (the call to `all` here).
Have you had a chance to try profiling?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I don't understand the suggestion about hoisting `has_nulls` out of the `all`.
Would you be able to please sketch it out in pseudo code or similar?
I haven't used macros much and will go educate myself on how they work (particularly the optimisation role of the compiler). Will also get to profiling (although maybe not until the weekend).