Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support min/max for Float16 type #12050

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ datafusion-expr = { workspace = true }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
half = { workspace = true }
log = { workspace = true }
paste = "1.0.14"
sqlparser = { workspace = true }
Expand Down
34 changes: 25 additions & 9 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,19 @@

use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray,
StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute;
use arrow::datatypes::{
DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type,
Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
};
use arrow_schema::IntervalUnit;
use datafusion_common::{
Expand All @@ -66,6 +67,7 @@ use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use half::f16;
use std::ops::Deref;

fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
Expand Down Expand Up @@ -181,6 +183,7 @@ impl AggregateUDFImpl for Max {
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
Expand Down Expand Up @@ -209,6 +212,9 @@ impl AggregateUDFImpl for Max {
UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_max_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_max_accumulator!(data_type, f32, Float32Type)
}
Expand Down Expand Up @@ -339,6 +345,9 @@ macro_rules! min_max_batch {
DataType::Float32 => {
typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
}
DataType::Float16 => {
typed_min_max_batch!($VALUES, Float16Array, Float16, $OP)
}
DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
Expand Down Expand Up @@ -623,6 +632,9 @@ macro_rules! min_max {
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
typed_min_max_float!(lhs, rhs, Float32, $OP)
}
(ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
typed_min_max_float!(lhs, rhs, Float16, $OP)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
typed_min_max!(lhs, rhs, UInt64, $OP)
}
Expand Down Expand Up @@ -950,6 +962,7 @@ impl AggregateUDFImpl for Min {
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
Expand Down Expand Up @@ -978,6 +991,9 @@ impl AggregateUDFImpl for Min {
UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_min_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_min_accumulator!(data_type, f32, Float32Type)
}
Expand Down
28 changes: 28 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5643,3 +5643,31 @@ query I??III?T
select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(NULL), nth_value(NULL, 1), string_agg(NULL, ',');
----
0 NULL NULL NULL NULL NULL NULL NULL

# test min/max Float16 without group expression
query RRTT
WITH data AS (
SELECT arrow_cast(1, 'Float16') AS f
UNION ALL
SELECT arrow_cast(6, 'Float16') AS f
)
SELECT MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) FROM data;
----
1 6 Float16 Float16

# test min/max Float16 with group expression
query IRRTT
WITH data AS (
SELECT 1 as k, arrow_cast(1.8125, 'Float16') AS f
UNION ALL
SELECT 1 as k, arrow_cast(6.8007813, 'Float16') AS f
UNION ALL
SELECT 2 AS k, arrow_cast(8.5, 'Float16') AS f
)
SELECT k, MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f))
FROM data
GROUP BY k
ORDER BY k;
----
1 1.8125 6.8007813 Float16 Float16
2 8.5 8.5 Float16 Float16
16 changes: 7 additions & 9 deletions datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ query error Error unrecognized word: unknown
SELECT arrow_cast('1', 'unknown')

# Round Trip tests:
query TTTTTTTTTTTTTTTTTTTTTTT
query TTTTTTTTTTTTTTTTTTTTTTTT
SELECT
arrow_typeof(arrow_cast(1, 'Int8')) as col_i8,
arrow_typeof(arrow_cast(1, 'Int16')) as col_i16,
Expand All @@ -112,8 +112,7 @@ SELECT
arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16,
arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32,
arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64,
-- can't seem to cast to Float16 for some reason
-- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16,
arrow_typeof(arrow_cast(1, 'Float16')) as col_f16,
arrow_typeof(arrow_cast(1, 'Float32')) as col_f32,
arrow_typeof(arrow_cast(1, 'Float64')) as col_f64,
arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8,
Expand All @@ -130,7 +129,7 @@ SELECT
arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, Some("+08:00"))')) as col_tstz_ns,
arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict
----
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8)
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8)



Expand All @@ -147,15 +146,14 @@ create table foo as select
arrow_cast(1, 'UInt16') as col_u16,
arrow_cast(1, 'UInt32') as col_u32,
arrow_cast(1, 'UInt64') as col_u64,
-- can't seem to cast to Float16 for some reason
-- arrow_cast(1.0, 'Float16') as col_f16,
arrow_cast(1.0, 'Float16') as col_f16,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

arrow_cast(1.0, 'Float32') as col_f32,
arrow_cast(1.0, 'Float64') as col_f64
;

## Ensure each column in the table has the expected type

query TTTTTTTTTT
query TTTTTTTTTTT
SELECT
arrow_typeof(col_i8),
arrow_typeof(col_i16),
Expand All @@ -165,12 +163,12 @@ SELECT
arrow_typeof(col_u16),
arrow_typeof(col_u32),
arrow_typeof(col_u64),
-- arrow_typeof(col_f16),
arrow_typeof(col_f16),
arrow_typeof(col_f32),
arrow_typeof(col_f64)
FROM foo;
----
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64


statement ok
Expand Down