
Commit a235c3b

yaooqinn authored and viirya committed
[SPARK-34037][SQL] Remove unnecessary upcasting for Avg & Sum, which handle overflow themselves internally
### What changes were proposed in this pull request?

The type coercion of numeric types for average and sum is unnecessary: their internal result and sum types already prevent overflow.

### Why are the changes needed?

Remove unnecessary logic that may cause potential performance regressions.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

TPC-DS plan stability tests.

Closes #31079 from yaooqinn/SPARK-34037.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
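To see why the rule is safe to drop, note that both aggregates already widen internally. Below is a simplified, illustrative sketch of that internal typing (not Spark's actual source; the names and precision bounds here are approximations):

```scala
import org.apache.spark.sql.types._

// Simplified sketch (not Spark's actual source) of the typing Sum and Average
// already perform internally, which makes an analyzer-side upcast redundant.
object WideningSketch {
  // Sum: integral inputs accumulate into a LongType buffer, and decimal
  // inputs get extra precision; fractional inputs keep their own width.
  def sumResultType(input: DataType): DataType = input match {
    case d: DecimalType =>
      DecimalType(math.min(d.precision + 10, DecimalType.MAX_PRECISION), d.scale)
    case ByteType | ShortType | IntegerType | LongType => LongType
    case other => other
  }

  // Average: the division is always carried out in a wide type, so there is
  // no overflow for an upfront Cast to prevent.
  def avgResultType(input: DataType): DataType = input match {
    case d: DecimalType =>
      DecimalType(
        math.min(d.precision + 4, DecimalType.MAX_PRECISION),
        math.min(d.scale + 4, DecimalType.MAX_SCALE))
    case _ => DoubleType
  }
}
```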
1 parent c75c29d · commit a235c3b


279 files changed (+1485 / −1496 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 0 additions & 11 deletions
@@ -634,17 +634,6 @@ object TypeCoercion {

         m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })

-      // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
-      case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
-      case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
-      case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
-
-      case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
-      case Average(e @ IntegralType()) if e.dataType != LongType =>
-        Average(Cast(e, LongType))
-      case Average(e @ FractionalType()) if e.dataType != DoubleType =>
-        Average(Cast(e, DoubleType))
-
       // Hive lets you do aggregation of timestamps... for some reason
       case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
       case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
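The plan-level effect of this deletion shows up in the golden files below; a quick way to reproduce it in a spark-shell (the temp view name `t` is illustrative):

```scala
// Illustrative check in spark-shell; table and column names are arbitrary.
spark.range(10)
  .selectExpr("CAST(id AS INT) AS val")
  .createOrReplaceTempView("t")

// Before this change the analyzed plan read sum(cast(val as bigint));
// with the rule removed it reads sum(val), the widening happening inside
// the Sum aggregate itself.
spark.sql("SELECT sum(val) FROM t").explain(true)
```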

sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out

Lines changed: 23 additions & 23 deletions
@@ -55,23 +55,23 @@ struct<plan:string>

 == Analyzed Logical Plan ==
 sum(DISTINCT val): bigint
-Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
 +- SubqueryAlias spark_catalog.default.explain_temp1
    +- Relation[key#x,val#x] parquet

 == Optimized Logical Plan ==
-Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
 +- Project [val#x]
    +- Relation[key#x,val#x] parquet

 == Physical Plan ==
 AdaptiveSparkPlan isFinalPlan=false
-+- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
++- HashAggregate(keys=[], functions=[sum(distinct val#x)], output=[sum(DISTINCT val)#xL])
    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
-      +- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
-         +- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
-            +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
-               +- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+      +- HashAggregate(keys=[], functions=[partial_sum(distinct val#x)], output=[sum#xL])
+         +- HashAggregate(keys=[val#x], functions=[], output=[val#x])
+            +- Exchange hashpartitioning(val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
+               +- HashAggregate(keys=[val#x], functions=[], output=[val#x])
                   +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>


@@ -615,7 +615,7 @@ Input [2]: [key#x, val#x]
 (14) HashAggregate
 Input [1]: [key#x]
 Keys: []
-Functions [1]: [partial_avg(cast(key#x as bigint))]
+Functions [1]: [partial_avg(key#x)]
 Aggregate Attributes [2]: [sum#x, count#xL]
 Results [2]: [sum#x, count#xL]

@@ -626,9 +626,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (16) HashAggregate
 Input [2]: [sum#x, count#xL]
 Keys: []
-Functions [1]: [avg(cast(key#x as bigint))]
-Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
-Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+Functions [1]: [avg(key#x)]
+Aggregate Attributes [1]: [avg(key#x)#x]
+Results [1]: [avg(key#x)#x AS avg(key)#x]

 (17) AdaptiveSparkPlan
 Output [1]: [avg(key)#x]
@@ -681,7 +681,7 @@ ReadSchema: struct<key:int>
 (5) HashAggregate
 Input [1]: [key#x]
 Keys: []
-Functions [1]: [partial_avg(cast(key#x as bigint))]
+Functions [1]: [partial_avg(key#x)]
 Aggregate Attributes [2]: [sum#x, count#xL]
 Results [2]: [sum#x, count#xL]

@@ -692,9 +692,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (7) HashAggregate
 Input [2]: [sum#x, count#xL]
 Keys: []
-Functions [1]: [avg(cast(key#x as bigint))]
-Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
-Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+Functions [1]: [avg(key#x)]
+Aggregate Attributes [1]: [avg(key#x)#x]
+Results [1]: [avg(key#x)#x AS avg(key)#x]

 (8) AdaptiveSparkPlan
 Output [1]: [avg(key)#x]
@@ -717,7 +717,7 @@ ReadSchema: struct<key:int>
 (10) HashAggregate
 Input [1]: [key#x]
 Keys: []
-Functions [1]: [partial_avg(cast(key#x as bigint))]
+Functions [1]: [partial_avg(key#x)]
 Aggregate Attributes [2]: [sum#x, count#xL]
 Results [2]: [sum#x, count#xL]

@@ -728,9 +728,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (12) HashAggregate
 Input [2]: [sum#x, count#xL]
 Keys: []
-Functions [1]: [avg(cast(key#x as bigint))]
-Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
-Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+Functions [1]: [avg(key#x)]
+Aggregate Attributes [1]: [avg(key#x)#x]
+Results [1]: [avg(key#x)#x AS avg(key)#x]

 (13) AdaptiveSparkPlan
 Output [1]: [avg(key)#x]
@@ -947,7 +947,7 @@ ReadSchema: struct<key:int,val:int>
 (2) HashAggregate
 Input [2]: [key#x, val#x]
 Keys: []
-Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
+Functions [3]: [partial_count(val#x), partial_sum(key#x), partial_count(key#x) FILTER (WHERE (val#x > 1))]
 Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
 Results [3]: [count#xL, sum#xL, count#xL]

@@ -958,9 +958,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (4) HashAggregate
 Input [3]: [count#xL, sum#xL, count#xL]
 Keys: []
-Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
-Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
-Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
+Functions [3]: [count(val#x), sum(key#x), count(key#x)]
+Aggregate Attributes [3]: [count(val#x)#xL, sum(key#x)#xL, count(key#x)#xL]
+Results [2]: [(count(val#x)#xL + sum(key#x)#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]

 (5) AdaptiveSparkPlan
 Output [2]: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL]

sql/core/src/test/resources/sql-tests/results/explain.sql.out

Lines changed: 19 additions & 19 deletions
@@ -55,22 +55,22 @@ struct<plan:string>

 == Analyzed Logical Plan ==
 sum(DISTINCT val): bigint
-Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
 +- SubqueryAlias spark_catalog.default.explain_temp1
    +- Relation[key#x,val#x] parquet

 == Optimized Logical Plan ==
-Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
 +- Project [val#x]
    +- Relation[key#x,val#x] parquet

 == Physical Plan ==
-*HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
+*HashAggregate(keys=[], functions=[sum(distinct val#x)], output=[sum(DISTINCT val)#xL])
 +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
-   +- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
-      +- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
-         +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
-            +- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+   +- *HashAggregate(keys=[], functions=[partial_sum(distinct val#x)], output=[sum#xL])
+      +- *HashAggregate(keys=[val#x], functions=[], output=[val#x])
+         +- Exchange hashpartitioning(val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
+            +- *HashAggregate(keys=[val#x], functions=[], output=[val#x])
                +- *ColumnarToRow
                   +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>

@@ -620,7 +620,7 @@ Input [2]: [key#x, val#x]
 (15) HashAggregate [codegen id : 1]
 Input [1]: [key#x]
 Keys: []
-Functions [1]: [partial_avg(cast(key#x as bigint))]
+Functions [1]: [partial_avg(key#x)]
 Aggregate Attributes [2]: [sum#x, count#xL]
 Results [2]: [sum#x, count#xL]

@@ -631,9 +631,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (17) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
 Keys: []
-Functions [1]: [avg(cast(key#x as bigint))]
-Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
-Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+Functions [1]: [avg(key#x)]
+Aggregate Attributes [1]: [avg(key#x)#x]
+Results [1]: [avg(key#x)#x AS avg(key)#x]


 -- !query
@@ -684,7 +684,7 @@ Input [1]: [key#x]
 (6) HashAggregate [codegen id : 1]
 Input [1]: [key#x]
 Keys: []
-Functions [1]: [partial_avg(cast(key#x as bigint))]
+Functions [1]: [partial_avg(key#x)]
 Aggregate Attributes [2]: [sum#x, count#xL]
 Results [2]: [sum#x, count#xL]

@@ -695,9 +695,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (8) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
 Keys: []
-Functions [1]: [avg(cast(key#x as bigint))]
-Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
-Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+Functions [1]: [avg(key#x)]
+Aggregate Attributes [1]: [avg(key#x)#x]
+Results [1]: [avg(key#x)#x AS avg(key)#x]

 Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x]

@@ -895,7 +895,7 @@ Input [2]: [key#x, val#x]
 (3) HashAggregate [codegen id : 1]
 Input [2]: [key#x, val#x]
 Keys: []
-Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
+Functions [3]: [partial_count(val#x), partial_sum(key#x), partial_count(key#x) FILTER (WHERE (val#x > 1))]
 Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
 Results [3]: [count#xL, sum#xL, count#xL]

@@ -906,9 +906,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 (5) HashAggregate [codegen id : 2]
 Input [3]: [count#xL, sum#xL, count#xL]
 Keys: []
-Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
-Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
-Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
+Functions [3]: [count(val#x), sum(key#x), count(key#x)]
+Aggregate Attributes [3]: [count(val#x)#xL, sum(key#x)#xL, count(key#x)#xL]
+Results [2]: [(count(val#x)#xL + sum(key#x)#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]


 -- !query

sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT))
+aggregate functions are not allowed in GROUP BY, but found sum(data.`b`)


 -- !query
@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT))
+aggregate functions are not allowed in GROUP BY, but found (sum(data.`b`) + CAST(2 AS BIGINT))


 -- !query

sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out

Lines changed: 2 additions & 2 deletions
@@ -381,8 +381,8 @@ struct<>
 org.apache.spark.sql.AnalysisException

 Aggregate/Window/Generate expressions are not valid in where clause of the query.
-Expression in where clause: [(sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT)) = CAST(b.`four` AS BIGINT))]
-Invalid expressions: [sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT))]
+Expression in where clause: [(sum(DISTINCT (outer(a.`four`) + b.`four`)) = CAST(b.`four` AS BIGINT))]
+Invalid expressions: [sum(DISTINCT (outer(a.`four`) + b.`four`))]


 -- !query

sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ AND t2b = (SELECT max(avg)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.
+grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(t2.`t2b`) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.


 -- !query

sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out

Lines changed: 2 additions & 2 deletions
@@ -372,8 +372,8 @@ struct<>
 org.apache.spark.sql.AnalysisException

 Aggregate/Window/Generate expressions are not valid in where clause of the query.
-Expression in where clause: [(sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT)) = CAST(CAST(udf(ansi_cast(four as string)) AS INT) AS BIGINT))]
-Invalid expressions: [sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT))]
+Expression in where clause: [(sum(DISTINCT (outer(a.`four`) + b.`four`)) = CAST(CAST(udf(ansi_cast(four as string)) AS INT) AS BIGINT))]
+Invalid expressions: [sum(DISTINCT (outer(a.`four`) + b.`four`))]


 -- !query

sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q19.sf100/explain.txt

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ Results [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27

 (37) Exchange
 Input [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27]
-Arguments: hashpartitioning(i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, 5), true, [id=#28]
+Arguments: hashpartitioning(i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, 5), ENSURE_REQUIREMENTS, [id=#28]

 (38) HashAggregate [codegen id : 7]
 Input [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27]

sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q19/explain.txt

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ Results [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27

 (37) Exchange
 Input [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27]
-Arguments: hashpartitioning(i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, 5), true, [id=#28]
+Arguments: hashpartitioning(i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, 5), ENSURE_REQUIREMENTS, [id=#28]

 (38) HashAggregate [codegen id : 7]
 Input [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27]
