Skip to content

Commit af92299

Browse files
dongjoon-hyunhvanhovell
authored andcommitted
[SPARK-14664][SQL] Implement DecimalAggregates optimization for Window queries
## What changes were proposed in this pull request? This PR aims to implement decimal aggregation optimization for window queries by improving existing `DecimalAggregates`. Historically, `DecimalAggregates` optimizer is designed to transform general `sum/avg(decimal)`, but it breaks recently added windows queries like the followings. The following queries work well without the current `DecimalAggregates` optimizer. **Sum** ```scala scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").head java.lang.RuntimeException: Unsupported window function: MakeDecimal((sum(UnscaledValue(a#31)),mode=Complete,isDistinct=false),12,1) scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23] : +- INPUT +- Window [MakeDecimal((sum(UnscaledValue(a#21)),mode=Complete,isDistinct=false),12,1) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#21] +- Scan OneRowRelation[] ``` **Average** ```scala scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").head java.lang.RuntimeException: Unsupported window function: cast(((avg(UnscaledValue(a#40)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5)) scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44] : +- INPUT +- Window [cast(((avg(UnscaledValue(a#42)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5)) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#42] +- Scan OneRowRelation[] ``` After this PR, those queries work fine and new optimized physical plans look like the followings. **Sum** ```scala scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35] : +- INPUT +- Window [MakeDecimal((sum(UnscaledValue(a#33)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),12,1) AS sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#33] +- Scan OneRowRelation[] ``` **Average** ```scala scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47] : +- INPUT +- Window [cast(((avg(UnscaledValue(a#45)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) / 10.0) as decimal(6,5)) AS avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#45] +- Scan OneRowRelation[] ``` In this PR, *SUM over window* pattern matching is based on the code of hvanhovell ; he should be credited for the work he did. ## How was this patch tested? Pass the Jenkins tests (with newly added testcases) Author: Dongjoon Hyun <[email protected]> Closes #12421 from dongjoon-hyun/SPARK-14664.
1 parent c74fd1e commit af92299

File tree

3 files changed

+161
-12
lines changed

3 files changed

+161
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,17 +1343,35 @@ object DecimalAggregates extends Rule[LogicalPlan] {
13431343
/** Maximum number of decimal digits representable precisely in a Double */
13441344
private val MAX_DOUBLE_DIGITS = 15
13451345

1346-
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
1347-
case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
1348-
if prec + 10 <= MAX_LONG_DIGITS =>
1349-
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
1350-
1351-
case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
1352-
if prec + 4 <= MAX_DOUBLE_DIGITS =>
1353-
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
1354-
Cast(
1355-
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
1356-
DecimalType(prec + 4, scale + 4))
1346+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1347+
case q: LogicalPlan => q transformExpressionsDown {
1348+
case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match {
1349+
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
1350+
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
1351+
prec + 10, scale)
1352+
1353+
case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
1354+
val newAggExpr =
1355+
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
1356+
Cast(
1357+
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
1358+
DecimalType(prec + 4, scale + 4))
1359+
1360+
case _ => we
1361+
}
1362+
case ae @ AggregateExpression(af, _, _, _) => af match {
1363+
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
1364+
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
1365+
1366+
case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
1367+
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
1368+
Cast(
1369+
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
1370+
DecimalType(prec + 4, scale + 4))
1371+
1372+
case _ => ae
1373+
}
1374+
}
13571375
}
13581376
}
13591377

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.dsl.plans._
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.plans.PlanTest
24+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
25+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
26+
import org.apache.spark.sql.types.DecimalType
27+
28+
class DecimalAggregatesSuite extends PlanTest {
29+
30+
object Optimize extends RuleExecutor[LogicalPlan] {
31+
val batches = Batch("Decimal Optimizations", FixedPoint(100),
32+
DecimalAggregates) :: Nil
33+
}
34+
35+
val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1))
36+
37+
test("Decimal Sum Aggregation: Optimized") {
38+
val originalQuery = testRelation.select(sum('a))
39+
val optimized = Optimize.execute(originalQuery.analyze)
40+
val correctAnswer = testRelation
41+
.select(MakeDecimal(sum(UnscaledValue('a)), 12, 1).as("sum(a)")).analyze
42+
43+
comparePlans(optimized, correctAnswer)
44+
}
45+
46+
test("Decimal Sum Aggregation: Not Optimized") {
47+
val originalQuery = testRelation.select(sum('b))
48+
val optimized = Optimize.execute(originalQuery.analyze)
49+
val correctAnswer = originalQuery.analyze
50+
51+
comparePlans(optimized, correctAnswer)
52+
}
53+
54+
test("Decimal Average Aggregation: Optimized") {
55+
val originalQuery = testRelation.select(avg('a))
56+
val optimized = Optimize.execute(originalQuery.analyze)
57+
val correctAnswer = testRelation
58+
.select((avg(UnscaledValue('a)) / 10.0).cast(DecimalType(6, 5)).as("avg(a)")).analyze
59+
60+
comparePlans(optimized, correctAnswer)
61+
}
62+
63+
test("Decimal Average Aggregation: Not Optimized") {
64+
val originalQuery = testRelation.select(avg('b))
65+
val optimized = Optimize.execute(originalQuery.analyze)
66+
val correctAnswer = originalQuery.analyze
67+
68+
comparePlans(optimized, correctAnswer)
69+
}
70+
71+
test("Decimal Sum Aggregation over Window: Optimized") {
72+
val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame)
73+
val originalQuery = testRelation.select(windowExpr(sum('a), spec).as('sum_a))
74+
val optimized = Optimize.execute(originalQuery.analyze)
75+
val correctAnswer = testRelation
76+
.select('a)
77+
.window(
78+
Seq(MakeDecimal(windowExpr(sum(UnscaledValue('a)), spec), 12, 1).as('sum_a)),
79+
Seq('a),
80+
Nil)
81+
.select('a, 'sum_a, 'sum_a)
82+
.select('sum_a)
83+
.analyze
84+
85+
comparePlans(optimized, correctAnswer)
86+
}
87+
88+
test("Decimal Sum Aggregation over Window: Not Optimized") {
89+
val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame)
90+
val originalQuery = testRelation.select(windowExpr(sum('b), spec))
91+
val optimized = Optimize.execute(originalQuery.analyze)
92+
val correctAnswer = originalQuery.analyze
93+
94+
comparePlans(optimized, correctAnswer)
95+
}
96+
97+
test("Decimal Average Aggregation over Window: Optimized") {
98+
val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame)
99+
val originalQuery = testRelation.select(windowExpr(avg('a), spec).as('avg_a))
100+
val optimized = Optimize.execute(originalQuery.analyze)
101+
val correctAnswer = testRelation
102+
.select('a)
103+
.window(
104+
Seq((windowExpr(avg(UnscaledValue('a)), spec) / 10.0).cast(DecimalType(6, 5)).as('avg_a)),
105+
Seq('a),
106+
Nil)
107+
.select('a, 'avg_a, 'avg_a)
108+
.select('avg_a)
109+
.analyze
110+
111+
comparePlans(optimized, correctAnswer)
112+
}
113+
114+
test("Decimal Average Aggregation over Window: Not Optimized") {
115+
val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame)
116+
val originalQuery = testRelation.select(windowExpr(avg('b), spec))
117+
val optimized = Optimize.execute(originalQuery.analyze)
118+
val correctAnswer = originalQuery.analyze
119+
120+
comparePlans(optimized, correctAnswer)
121+
}
122+
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.functions._
2222
import org.apache.spark.sql.internal.SQLConf
2323
import org.apache.spark.sql.test.SharedSQLContext
2424
import org.apache.spark.sql.test.SQLTestData.DecimalData
25-
import org.apache.spark.sql.types.DecimalType
25+
import org.apache.spark.sql.types.{Decimal, DecimalType}
2626

2727
case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
2828

@@ -430,4 +430,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
430430
expr("kurtosis(a)")),
431431
Row(null, null, null, null, null))
432432
}
433+
434+
test("SPARK-14664: Decimal sum/avg over window should work.") {
435+
checkAnswer(
436+
sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
437+
Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil)
438+
checkAnswer(
439+
sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
440+
Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
441+
}
433442
}

0 commit comments

Comments
 (0)