Fix DistinctCount for timestamps with time zone
Preserve the original data type in the aggregation state
joroKr21 committed Apr 11, 2024
1 parent b9759b9 commit cb59fbd
Showing 3 changed files with 56 additions and 26 deletions.
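
Why the fix works (context, not part of the commit): PrimitiveDistinctCountAccumulator stores only the raw native values and rebuilds an Arrow array from them when it emits its state. PrimitiveArray::from_iter_values falls back to the type parameter's default data type, which for timestamps carries no time zone (and for decimals the default precision and scale), so the rebuilt state no longer matched the declared schema for timestamps with a time zone. Threading the original DataType through with_data_type preserves it. A minimal standalone sketch of that behaviour, assuming the arrow-rs crates (arrow_array, arrow_schema) that DataFusion builds on:

    use arrow_array::{types::TimestampNanosecondType, Array, PrimitiveArray};
    use arrow_schema::{DataType, TimeUnit};

    fn main() {
        // Rebuilding from raw i64 values yields the type's default data type,
        // i.e. a timestamp with no time zone.
        let plain = PrimitiveArray::<TimestampNanosecondType>::from_iter_values([1, 2, 3]);
        assert_eq!(
            plain.data_type(),
            &DataType::Timestamp(TimeUnit::Nanosecond, None)
        );

        // Overriding it with the original data type, as the fix does, keeps the zone.
        let utc = DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));
        let with_tz = plain.with_data_type(utc.clone());
        assert_eq!(with_tz.data_type(), &utc);
    }
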
42 changes: 24 additions & 18 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -109,12 +109,14 @@ impl AggregateExpr for DistinctCount {
UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
-Decimal128(_, _) => {
-Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
-}
-Decimal256(_, _) => {
-Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
-}
+dt @ Decimal128(_, _) => Box::new(
+PrimitiveDistinctCountAccumulator::<Decimal128Type>::new()
+.with_data_type(dt.clone()),
+),
+dt @ Decimal256(_, _) => Box::new(
+PrimitiveDistinctCountAccumulator::<Decimal256Type>::new()
+.with_data_type(dt.clone()),
+),

Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
@@ -130,18 +132,22 @@ impl AggregateExpr for DistinctCount {
Time64(Nanosecond) => {
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
}
-Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
-TimestampMicrosecondType,
->::new()),
-Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
-TimestampMillisecondType,
->::new()),
-Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
-TimestampNanosecondType,
->::new()),
-Timestamp(Second, _) => {
-Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
-}
+dt @ Timestamp(Microsecond, _) => Box::new(
+PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new()
+.with_data_type(dt.clone()),
+),
+dt @ Timestamp(Millisecond, _) => Box::new(
+PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new()
+.with_data_type(dt.clone()),
+),
+dt @ Timestamp(Nanosecond, _) => Box::new(
+PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new()
+.with_data_type(dt.clone()),
+),
+dt @ Timestamp(Second, _) => Box::new(
+PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new()
+.with_data_type(dt.clone()),
+),

Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
Float32 => Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()),
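
A side note on the pattern used in the new match arms above (illustrative, not from the commit): the dt @ Timestamp(...) and dt @ Decimal128(...) arms use Rust's @ pattern syntax, which checks the variant shape while binding the whole DataType to dt, so the time zone (or precision and scale) inside the matched value stays available for the with_data_type call. A small sketch of the idiom:

    use arrow_schema::{DataType, TimeUnit};

    // `dt @ pattern` both matches the shape and binds the complete value to dt.
    fn describe(data_type: &DataType) -> String {
        match data_type {
            dt @ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                format!("nanosecond timestamp, full type: {dt}")
            }
            dt @ DataType::Decimal128(_, _) => format!("decimal, full type: {dt}"),
            other => format!("other type: {other}"),
        }
    }

    fn main() {
        let ts = DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));
        // Prints the full type, including the UTC time zone.
        println!("{}", describe(&ts));
    }
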
15 changes: 12 additions & 3 deletions datafusion/physical-expr/src/aggregate/count_distinct/native.rs
@@ -30,6 +30,7 @@ use ahash::RandomState;
use arrow::array::ArrayRef;
use arrow_array::types::ArrowPrimitiveType;
use arrow_array::PrimitiveArray;
+use arrow_schema::DataType;

use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
@@ -45,6 +46,7 @@ where
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
+data_type: DataType,
}

impl<T> PrimitiveDistinctCountAccumulator<T>
@@ -55,8 +57,14 @@ where
pub(super) fn new() -> Self {
Self {
values: HashSet::default(),
+data_type: T::DATA_TYPE,
}
}

+pub(super) fn with_data_type(mut self, data_type: DataType) -> Self {
+self.data_type = data_type;
+self
+}
}

impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
@@ -65,9 +73,10 @@ where
T::Native: Eq + Hash,
{
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
-let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
-self.values.iter().cloned(),
-)) as ArrayRef;
+let arr = Arc::new(
+PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
+.with_data_type(self.data_type.clone()),
+);
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}
25 changes: 20 additions & 5 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -1876,18 +1876,22 @@ select
arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros,
arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis,
arrow_cast(column1, 'Timestamp(Second, None)') as secs,
+arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc,
+arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc,
+arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc,
+arrow_cast(column1, 'Timestamp(Second, Some("UTC"))') as secs_utc,
column2 as names,
column3 as tag
from t_source;

# Demonstrate the contents
-query PPPPTT
+query PPPPPPPPTT
select * from t;
----
-2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X
-2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 X
-NULL NULL NULL NULL Row 2 Y
-2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 Y
+2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 X
+2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 X
+NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y
+2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 Y


# aggregate_timestamps_sum
@@ -1933,6 +1937,17 @@ SELECT tag, max(nanos), max(micros), max(millis), max(secs) FROM t GROUP BY tag
X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10
Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10

+# aggregate_timestamps_count_distinct_with_tz
+query IIII
+SELECT count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t;
+----
+3 3 3 3
+
+query TIIII
+SELECT tag, count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t GROUP BY tag ORDER BY tag;
+----
+X 2 2 2 2
+Y 1 1 1 1

# aggregate_timestamps_avg
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\.
