diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 1c07f61f10cd5..77e73a39fe381 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -342,9 +342,17 @@ fn group_aggregate_batch( // this is an optimization to avoid allocating `key` on every row. // it will be overwritten on every iteration of the loop below let mut group_by_values = Vec::with_capacity(group_values.len()); + let mut null_information = Vec::with_capacity(group_values.len()); + /* for _ in 0..group_values.len() { group_by_values.push(ScalarValue::UInt32(Some(0))); } + */ + + group_values.iter().for_each(|array| { + null_information.push(array.null_count() > 0); + group_by_values.push(ScalarValue::UInt32(Some(0))); + }); // 1.1 construct the key from the group values // 1.2 construct the mapping key if it does not exist @@ -368,7 +376,10 @@ fn group_aggregate_batch( group_values .iter() .zip(group_state.group_by_values.iter()) - .all(|(array, scalar)| scalar.eq_array(array, row)) + .zip(null_information.iter()) + .all(|((array, scalar), has_nulls)| { + scalar.eq_array(array, row, has_nulls) + }) }); match entry { diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 86d17654c0604..abe0be42ae238 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -424,6 +424,25 @@ macro_rules! build_array_from_option { }}; } +macro_rules! eq_array_general { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr, $has_nulls:expr) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + if *$has_nulls { + let is_valid = array.is_valid($index); + match $VALUE { + Some(val) => is_valid && &array.value($index) == val, + None => !is_valid, + } + } else { + match $VALUE { + Some(val) => &array.value($index) == val, + None => false, + } + } + }}; +} + +/* macro_rules! eq_array_primitive { ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); @@ -435,6 +454,17 @@ macro_rules! eq_array_primitive { }}; } +macro_rules! eq_array_no_nulls_primitive { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + match $VALUE { + Some(val) => &array.value($index) == val, + None => false, + } + }}; +} +*/ + impl ScalarValue { /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { @@ -1028,69 +1058,81 @@ impl ScalarValue { /// This function has a few narrow usescases such as hash table key /// comparisons where comparing a single row at a time is necessary. #[inline] - pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { + pub fn eq_array(&self, array: &ArrayRef, index: usize, has_nulls: &bool) -> bool { if let DataType::Dictionary(key_type, _) = array.data_type() { - return self.eq_array_dictionary(array, index, key_type); + return self.eq_array_dictionary(array, index, key_type, has_nulls); } match self { ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val) + eq_array_general!(array, index, BooleanArray, val, has_nulls) } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val) + eq_array_general!(array, index, Float32Array, val, has_nulls) } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val) + eq_array_general!(array, index, Float64Array, val, has_nulls) + } + ScalarValue::Int8(val) => { + eq_array_general!(array, index, Int8Array, val, has_nulls) + } + ScalarValue::Int16(val) => { + eq_array_general!(array, index, Int16Array, val, has_nulls) + } + ScalarValue::Int32(val) => { + eq_array_general!(array, index, Int32Array, val, has_nulls) + } + ScalarValue::Int64(val) => { + eq_array_general!(array, index, Int64Array, val, has_nulls) + } + ScalarValue::UInt8(val) => { + eq_array_general!(array, index, UInt8Array, val, has_nulls) } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), - ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), - ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), - ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), - ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val) + eq_array_general!(array, index, UInt16Array, val, has_nulls) } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val) + eq_array_general!(array, index, UInt32Array, val, has_nulls) } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val) + eq_array_general!(array, index, UInt64Array, val, has_nulls) + } + ScalarValue::Utf8(val) => { + eq_array_general!(array, index, StringArray, val, has_nulls) } - ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val) + eq_array_general!(array, index, LargeStringArray, val, has_nulls) } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val) + eq_array_general!(array, index, BinaryArray, val, has_nulls) } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val) + eq_array_general!(array, index, LargeBinaryArray, val, has_nulls) } ScalarValue::List(_, _) => unimplemented!(), ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val) + eq_array_general!(array, index, Date32Array, val, has_nulls) } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val) + eq_array_general!(array, index, Date64Array, val, has_nulls) } ScalarValue::TimestampSecond(val) => { - eq_array_primitive!(array, index, TimestampSecondArray, val) + eq_array_general!(array, index, TimestampSecondArray, val, has_nulls) } ScalarValue::TimestampMillisecond(val) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val) + eq_array_general!(array, index, TimestampMillisecondArray, val, has_nulls) } ScalarValue::TimestampMicrosecond(val) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + eq_array_general!(array, index, TimestampMicrosecondArray, val, has_nulls) } ScalarValue::TimestampNanosecond(val) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val) + eq_array_general!(array, index, TimestampNanosecondArray, val, has_nulls) } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val) + eq_array_general!(array, index, IntervalYearMonthArray, val, has_nulls) } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val) + eq_array_general!(array, index, IntervalDayTimeArray, val, has_nulls) } } } @@ -1102,6 +1144,7 @@ impl ScalarValue { array: &ArrayRef, index: usize, key_type: &DataType, + has_nulls: &bool, ) -> bool { let (values, values_index) = match key_type { DataType::Int8 => get_dict_value::(array, index).unwrap(), @@ -1116,7 +1159,7 @@ impl ScalarValue { }; match values_index { - Some(values_index) => self.eq_array(values, values_index), + Some(values_index) => self.eq_array(values, values_index, has_nulls), None => self.is_null(), } } @@ -1914,7 +1957,7 @@ mod tests { for (index, scalar) in scalars.into_iter().enumerate() { assert!( - scalar.eq_array(&array, index), + scalar.eq_array(&array, index, &true), "Expected {:?} to be equal to {:?} at index {}", scalar, array, @@ -1925,7 +1968,7 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index), + !scalar.eq_array(&array, other_index, &true), "Expected {:?} to be NOT equal to {:?} at index {}", scalar, array,