Skip to content

Commit

Permalink
Add regr_*() aggregate functions (#7211)
Browse files Browse the repository at this point in the history
  • Loading branch information
2010YOUY01 committed Aug 8, 2023
1 parent 85c9fe1 commit d1361d5
Show file tree
Hide file tree
Showing 13 changed files with 728 additions and 150 deletions.
331 changes: 221 additions & 110 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2294,157 +2294,268 @@ NULL


#
# regr_slope() tests
# regr_*() tests
#

# invalid input
# regr_*() invalid input
statement error
select regr_slope();

statement error
select regr_slope(*);
select regr_intercept(*);

statement error
select regr_slope(*) from aggregate_test_100;
select regr_count(*) from aggregate_test_100;

statement error
select regr_slope(1);
select regr_r2(1);

statement error
select regr_slope(1,2,3);
select regr_avgx(1,2,3);

statement error
select regr_slope(1, 'foo');
select regr_avgy(1, 'foo');

statement error
select regr_slope('foo', 1);
select regr_sxx('foo', 1);

statement error
select regr_slope('foo', 'bar');
select regr_syy('foo', 'bar');

statement error
select regr_sxy(NULL, 'bar');


# regr_slope() NULL result
query R
select regr_slope(1,1);
----
NULL

query R
select regr_slope(1, NULL);
----
NULL

query R
select regr_slope(NULL, 1);
----
NULL

query R
select regr_slope(NULL, NULL);
----
NULL

query R
select regr_slope(column2, column1) from (values (1,2), (1,4), (1,6));
----
NULL



# regr_slope() basic tests
query R
select regr_slope(column2, column1) from (values (1,2), (2,4), (3,6));
----
2

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628



# regr_slope() ignore NULLs
query R
select regr_slope(column2, column1) from (values (1,NULL), (2,4), (3,6));
----
2

query R
select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (3,6));
----
NULL

query R
select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (NULL,NULL));
----
NULL

query TR rowsort
select column3, regr_slope(column2, column1)

# regr_*() NULL results
# Every query below evaluates all nine regr_*() aggregates over the same
# argument pair. Expected columns follow select-list order:
# slope, intercept, count, r2, avgx, avgy, sxx, syy, sxy.

# Single constant pair (1,1): count is 1, both averages are 1 and
# sxx/syy/sxy are 0; slope, intercept and r2 are expected to be NULL.
query RRRRRRRRR
select regr_slope(1,1), regr_intercept(1,1), regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), regr_sxy(1,1);
----
NULL NULL 1 NULL 1 1 0 0 0

# NULL second argument: the pair is not counted (count 0) and every other
# aggregate is NULL.
query RRRRRRRRR
select regr_slope(1, NULL), regr_intercept(1, NULL), regr_count(1, NULL), regr_r2(1, NULL), regr_avgx(1, NULL), regr_avgy(1, NULL), regr_sxx(1, NULL), regr_syy(1, NULL), regr_sxy(1, NULL);
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

# NULL first argument: same expectation as above.
query RRRRRRRRR
select regr_slope(NULL, 1), regr_intercept(NULL, 1), regr_count(NULL, 1), regr_r2(NULL, 1), regr_avgx(NULL, 1), regr_avgy(NULL, 1), regr_sxx(NULL, 1), regr_syy(NULL, 1), regr_sxy(NULL, 1);
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

# Both arguments NULL: same expectation as above.
query RRRRRRRRR
select regr_slope(NULL, NULL), regr_intercept(NULL, NULL), regr_count(NULL, NULL), regr_r2(NULL, NULL), regr_avgx(NULL, NULL), regr_avgy(NULL, NULL), regr_sxx(NULL, NULL), regr_syy(NULL, NULL), regr_sxy(NULL, NULL);
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

# Three rows where the second argument (column1) is constant at 1 while the
# first argument (column2) takes 2, 4, 6. Expected avgx (1) is the mean of
# column1 and avgy (4) the mean of column2; sxx and sxy are 0, syy is 8, and
# slope, intercept and r2 are NULL.
query RRRRRRRRR
select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), regr_r2(column2, column1), regr_avgx(column2, column1), regr_avgy(column2, column1), regr_sxx(column2, column1), regr_syy(column2, column1), regr_sxy(column2, column1) from (values (1,2), (1,4), (1,6));
----
NULL NULL 3 NULL 1 4 0 8 0



# regr_*() basic tests
# Expected columns follow select-list order:
# slope, intercept, count, r2, avgx, avgy, sxx, syy, sxy.

# Exactly linear input (column2 = 2 * column1 for all three rows):
# slope 2, intercept 0, r2 1, avgx 2, avgy 4, sxx 2, syy 8, sxy 4.
query RRRRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
regr_count(column2, column1),
regr_r2(column2, column1),
regr_avgx(column2, column1),
regr_avgy(column2, column1),
regr_sxx(column2, column1),
regr_syy(column2, column1),
regr_sxy(column2, column1)
from (values (1,2), (2,4), (3,6));
----
2 0 3 1 2 4 2 8 4

# Same nine aggregates over columns c12 and c11 of aggregate_test_100
# (the table is created outside this section). Expected count is 100.
query RRRRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
regr_count(c12, c11),
regr_r2(c12, c11),
regr_avgx(c12, c11),
regr_avgy(c12, c11),
regr_sxx(c12, c11),
regr_syy(c12, c11),
regr_sxy(c12, c11)
from aggregate_test_100;
----
0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695



# regr_*() functions ignore NULLs
# A row contributes only when both arguments are non-NULL. Expected columns
# follow select-list order:
# slope, intercept, count, r2, avgx, avgy, sxx, syy, sxy.

# Row (1,NULL) is skipped, leaving (2,4) and (3,6): count 2, avgx 2.5,
# avgy 5, slope 2, intercept 0.
query RRRRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
regr_count(column2, column1),
regr_r2(column2, column1),
regr_avgx(column2, column1),
regr_avgy(column2, column1),
regr_sxx(column2, column1),
regr_syy(column2, column1),
regr_sxy(column2, column1)
from (values (1,NULL), (2,4), (3,6));
----
2 0 2 1 2.5 5 0.5 2 1

# Only (3,6) has both values present: count 1, avgx 3, avgy 6, the sums of
# squares are 0, and slope, intercept and r2 are NULL.
query RRRRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
regr_count(column2, column1),
regr_r2(column2, column1),
regr_avgx(column2, column1),
regr_avgy(column2, column1),
regr_sxx(column2, column1),
regr_syy(column2, column1),
regr_sxy(column2, column1)
from (values (1,NULL), (NULL,4), (3,6));
----
NULL NULL 1 NULL 3 6 0 0 0

# No row has both values present: count 0 and every other aggregate is NULL.
query RRRRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
regr_count(column2, column1),
regr_r2(column2, column1),
regr_avgx(column2, column1),
regr_avgy(column2, column1),
regr_sxx(column2, column1),
regr_syy(column2, column1),
regr_sxy(column2, column1)
from (values (1,NULL), (NULL,4), (NULL,NULL));
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

query TRRRRRRRRR rowsort
select
column3,
regr_slope(column2, column1),
regr_intercept(column2, column1),
regr_count(column2, column1),
regr_r2(column2, column1),
regr_avgx(column2, column1),
regr_avgy(column2, column1),
regr_sxx(column2, column1),
regr_syy(column2, column1),
regr_sxy(column2, column1)
from (values (1,2,'a'), (2,4,'a'), (1,3,'b'), (3,9,'b'), (1,10,'c'), (NULL,100,'c'))
group by column3;
----
a 2
b 3
c NULL
a 2 0 2 1 1.5 3 0.5 2 1
b 3 0 2 1 2 6 2 18 6
c NULL NULL 1 NULL 1 10 0 0 0



# regr_slope() testing merge_batch() from RegrSlopeAccumulator's internal implementation
# regr_*() testing merge_batch() from RegrAccumulator's internal implementation
statement ok
set datafusion.execution.batch_size = 1;

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628
query RRRRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
regr_count(c12, c11),
regr_r2(c12, c11),
regr_avgx(c12, c11),
regr_avgy(c12, c11),
regr_sxx(c12, c11),
regr_syy(c12, c11),
regr_sxy(c12, c11)
from aggregate_test_100;
----
0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695

statement ok
set datafusion.execution.batch_size = 2;

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628
query RRRRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
regr_count(c12, c11),
regr_r2(c12, c11),
regr_avgx(c12, c11),
regr_avgy(c12, c11),
regr_sxx(c12, c11),
regr_syy(c12, c11),
regr_sxy(c12, c11)
from aggregate_test_100;
----
0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695

statement ok
set datafusion.execution.batch_size = 3;

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628
query RRRRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
regr_count(c12, c11),
regr_r2(c12, c11),
regr_avgx(c12, c11),
regr_avgy(c12, c11),
regr_sxx(c12, c11),
regr_syy(c12, c11),
regr_sxy(c12, c11)
from aggregate_test_100;
----
0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695

statement ok
set datafusion.execution.batch_size = 8192;



# regr_slope testing retract_batch() from RegrSlopeAccumulator's internal implementation
query R
select regr_slope(column2, column1)
over (order by column1 rows between 2 preceding and current row)
from (values (1,2), (2,4), (3,6), (4,12), (5,15), (6, 18));
----
NULL
2
2
4
4.5
3

query R
select regr_slope(column2, column1)
over (order by column1 rows between 2 preceding and current row)
from (values (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7, 21));
----
NULL
2
2
2
NULL
NULL
3
3
# regr_*() testing retract_batch() from RegrAccumulator's internal implementation
# All nine aggregates share the named window w, a sliding frame of at most
# three rows (ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) ordered by column1,
# so earlier rows drop out of the frame as it advances.
# Expected columns: slope, intercept, count, r2, avgx, avgy, sxx, syy, sxy.
# The first expected row covers a one-row frame (count 1, slope NULL).
# NOTE(review): there is no outer ORDER BY or rowsort; the expected rows are
# listed in column1 order, presumably relying on the window sort - confirm.
query RRRRRRRRR
SELECT
regr_slope(column2, column1) OVER w AS slope,
regr_intercept(column2, column1) OVER w AS intercept,
regr_count(column2, column1) OVER w AS count,
regr_r2(column2, column1) OVER w AS r2,
regr_avgx(column2, column1) OVER w AS avgx,
regr_avgy(column2, column1) OVER w AS avgy,
regr_sxx(column2, column1) OVER w AS sxx,
regr_syy(column2, column1) OVER w AS syy,
regr_sxy(column2, column1) OVER w AS sxy
FROM (VALUES (1,2), (2,4), (3,6), (4,12), (5,15), (6,18)) AS t(column1, column2)
WINDOW w AS (ORDER BY column1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW);
----
NULL NULL 1 NULL 1 2 0 0 0
2 0 2 1 1.5 3 0.5 2 1
2 0 3 1 2 4 2 8 4
4 -4.666666666667 3 0.923076923077 3 7.333333333333 2 34.666666666667 8
4.5 -7 3 0.964285714286 4 11 2 42 9
3 0 3 1 5 15 2 18 6

# Same sliding frame, with NULL column2 values entering and leaving it.
# Frames holding a single complete pair report count 1 and a NULL slope
# (5th and 6th expected rows).
# NOTE(review): rows (3,6) and (3, NULL) tie on column1 and the window
# ORDER BY has no tie-breaker. The 3rd expected row (count 3) assumes (3,6)
# is ordered before (3, NULL) - confirm the engine guarantees that order.
query RRRRRRRRR
SELECT
regr_slope(column2, column1) OVER w AS slope,
regr_intercept(column2, column1) OVER w AS intercept,
regr_count(column2, column1) OVER w AS count,
regr_r2(column2, column1) OVER w AS r2,
regr_avgx(column2, column1) OVER w AS avgx,
regr_avgy(column2, column1) OVER w AS avgy,
regr_sxx(column2, column1) OVER w AS sxx,
regr_syy(column2, column1) OVER w AS syy,
regr_sxy(column2, column1) OVER w AS sxy
FROM (VALUES (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7, 21)) AS t(column1, column2)
WINDOW w AS (ORDER BY column1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW);
----
NULL NULL 1 NULL 1 2 0 0 0
2 0 2 1 1.5 3 0.5 2 1
2 0 3 1 2 4 2 8 4
2 0 2 1 2.5 5 0.5 2 1
NULL NULL 1 NULL 3 6 0 0 0
NULL NULL 1 NULL 5 15 0 0 0
3 0 2 1 5.5 16.5 0.5 4.5 1.5
3 0 3 1 6 18 2 18 6
Loading

0 comments on commit d1361d5

Please sign in to comment.