Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,17 @@ fn group_aggregate_batch(
// this is an optimization to avoid allocating `key` on every row.
// it will be overwritten on every iteration of the loop below
let mut group_by_values = Vec::with_capacity(group_values.len());
let mut null_information = Vec::with_capacity(group_values.len());
/*
for _ in 0..group_values.len() {
group_by_values.push(ScalarValue::UInt32(Some(0)));
}
*/

group_values.iter().for_each(|array| {
null_information.push(array.null_count() > 0);
group_by_values.push(ScalarValue::UInt32(Some(0)));
});

// 1.1 construct the key from the group values
// 1.2 construct the mapping key if it does not exist
Expand All @@ -368,7 +376,10 @@ fn group_aggregate_batch(
group_values
.iter()
.zip(group_state.group_by_values.iter())
.all(|(array, scalar)| scalar.eq_array(array, row))
.zip(null_information.iter())
.all(|((array, scalar), has_nulls)| {
scalar.eq_array(array, row, has_nulls)
})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, broadly speaking, this is the pattern, though I am not sure how well the compiler will be able to optimize it given the `has_nulls` parameter -- the `has_nulls` check might need to be hoisted out of this loop (the call to `all` here).

Have you had a chance to try profiling?

Copy link
Author

@novemberkilo novemberkilo Aug 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I don't understand the suggestion about hoisting `has_nulls` out of the `all` call.

Would you be able to please sketch it out in pseudo code or similar?

I haven't used macros much and will go educate myself on how they work (particularly the optimisation role of the compiler). Will also get to profiling (although maybe not until the weekend).

});

match entry {
Expand Down
99 changes: 71 additions & 28 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,25 @@ macro_rules! build_array_from_option {
}};
}

macro_rules! eq_array_general {
($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr, $has_nulls:expr) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
if *$has_nulls {
Copy link
Contributor

@Dandandan Dandandan Sep 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch should ideally be outside of the loop at
https://github.com/apache/arrow-datafusion/pull/922/files#diff-03876812a8bef4074e517600fdcf8e6b49f1ea24df44905d6d806836fd61b2a8L360

But this might be hard to accomplish given the current design. I posted some ideas in my earlier comment on this PR.

let is_valid = array.is_valid($index);
match $VALUE {
Some(val) => is_valid && &array.value($index) == val,
None => !is_valid,
}
} else {
match $VALUE {
Some(val) => &array.value($index) == val,
None => false,
}
}
}};
}

/*
macro_rules! eq_array_primitive {
($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
Expand All @@ -435,6 +454,17 @@ macro_rules! eq_array_primitive {
}};
}

macro_rules! eq_array_no_nulls_primitive {
($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
match $VALUE {
Some(val) => &array.value($index) == val,
None => false,
}
}};
}
*/

impl ScalarValue {
/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
Expand Down Expand Up @@ -1028,69 +1058,81 @@ impl ScalarValue {
/// This function has a few narrow use cases such as hash table key
/// comparisons where comparing a single row at a time is necessary.
#[inline]
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
pub fn eq_array(&self, array: &ArrayRef, index: usize, has_nulls: &bool) -> bool {
if let DataType::Dictionary(key_type, _) = array.data_type() {
return self.eq_array_dictionary(array, index, key_type);
return self.eq_array_dictionary(array, index, key_type, has_nulls);
}

match self {
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)
eq_array_general!(array, index, BooleanArray, val, has_nulls)
}
ScalarValue::Float32(val) => {
eq_array_primitive!(array, index, Float32Array, val)
eq_array_general!(array, index, Float32Array, val, has_nulls)
}
ScalarValue::Float64(val) => {
eq_array_primitive!(array, index, Float64Array, val)
eq_array_general!(array, index, Float64Array, val, has_nulls)
}
ScalarValue::Int8(val) => {
eq_array_general!(array, index, Int8Array, val, has_nulls)
}
ScalarValue::Int16(val) => {
eq_array_general!(array, index, Int16Array, val, has_nulls)
}
ScalarValue::Int32(val) => {
eq_array_general!(array, index, Int32Array, val, has_nulls)
}
ScalarValue::Int64(val) => {
eq_array_general!(array, index, Int64Array, val, has_nulls)
}
ScalarValue::UInt8(val) => {
eq_array_general!(array, index, UInt8Array, val, has_nulls)
}
ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val),
ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val),
ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val),
ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val),
ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val),
ScalarValue::UInt16(val) => {
eq_array_primitive!(array, index, UInt16Array, val)
eq_array_general!(array, index, UInt16Array, val, has_nulls)
}
ScalarValue::UInt32(val) => {
eq_array_primitive!(array, index, UInt32Array, val)
eq_array_general!(array, index, UInt32Array, val, has_nulls)
}
ScalarValue::UInt64(val) => {
eq_array_primitive!(array, index, UInt64Array, val)
eq_array_general!(array, index, UInt64Array, val, has_nulls)
}
ScalarValue::Utf8(val) => {
eq_array_general!(array, index, StringArray, val, has_nulls)
}
ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val),
ScalarValue::LargeUtf8(val) => {
eq_array_primitive!(array, index, LargeStringArray, val)
eq_array_general!(array, index, LargeStringArray, val, has_nulls)
}
ScalarValue::Binary(val) => {
eq_array_primitive!(array, index, BinaryArray, val)
eq_array_general!(array, index, BinaryArray, val, has_nulls)
}
ScalarValue::LargeBinary(val) => {
eq_array_primitive!(array, index, LargeBinaryArray, val)
eq_array_general!(array, index, LargeBinaryArray, val, has_nulls)
}
ScalarValue::List(_, _) => unimplemented!(),
ScalarValue::Date32(val) => {
eq_array_primitive!(array, index, Date32Array, val)
eq_array_general!(array, index, Date32Array, val, has_nulls)
}
ScalarValue::Date64(val) => {
eq_array_primitive!(array, index, Date64Array, val)
eq_array_general!(array, index, Date64Array, val, has_nulls)
}
ScalarValue::TimestampSecond(val) => {
eq_array_primitive!(array, index, TimestampSecondArray, val)
eq_array_general!(array, index, TimestampSecondArray, val, has_nulls)
}
ScalarValue::TimestampMillisecond(val) => {
eq_array_primitive!(array, index, TimestampMillisecondArray, val)
eq_array_general!(array, index, TimestampMillisecondArray, val, has_nulls)
}
ScalarValue::TimestampMicrosecond(val) => {
eq_array_primitive!(array, index, TimestampMicrosecondArray, val)
eq_array_general!(array, index, TimestampMicrosecondArray, val, has_nulls)
}
ScalarValue::TimestampNanosecond(val) => {
eq_array_primitive!(array, index, TimestampNanosecondArray, val)
eq_array_general!(array, index, TimestampNanosecondArray, val, has_nulls)
}
ScalarValue::IntervalYearMonth(val) => {
eq_array_primitive!(array, index, IntervalYearMonthArray, val)
eq_array_general!(array, index, IntervalYearMonthArray, val, has_nulls)
}
ScalarValue::IntervalDayTime(val) => {
eq_array_primitive!(array, index, IntervalDayTimeArray, val)
eq_array_general!(array, index, IntervalDayTimeArray, val, has_nulls)
}
}
}
Expand All @@ -1102,6 +1144,7 @@ impl ScalarValue {
array: &ArrayRef,
index: usize,
key_type: &DataType,
has_nulls: &bool,
) -> bool {
let (values, values_index) = match key_type {
DataType::Int8 => get_dict_value::<Int8Type>(array, index).unwrap(),
Expand All @@ -1116,7 +1159,7 @@ impl ScalarValue {
};

match values_index {
Some(values_index) => self.eq_array(values, values_index),
Some(values_index) => self.eq_array(values, values_index, has_nulls),
None => self.is_null(),
}
}
Expand Down Expand Up @@ -1914,7 +1957,7 @@ mod tests {

for (index, scalar) in scalars.into_iter().enumerate() {
assert!(
scalar.eq_array(&array, index),
scalar.eq_array(&array, index, &true),
"Expected {:?} to be equal to {:?} at index {}",
scalar,
array,
Expand All @@ -1925,7 +1968,7 @@ mod tests {
for other_index in 0..array.len() {
if index != other_index {
assert!(
!scalar.eq_array(&array, other_index),
!scalar.eq_array(&array, other_index, &true),
"Expected {:?} to be NOT equal to {:?} at index {}",
scalar,
array,
Expand Down