Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit b7720ba

Browse files
committed
Add an analysis rule to convert aggregate function to the new version.
1 parent 5c00f3f commit b7720ba

File tree

5 files changed

+45
-13
lines changed

5 files changed

+45
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
21-
import org.apache.spark.sql.catalyst.expressions.aggregate2.{AggregateExpression2, AggregateFunction2}
2221
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
2322
import org.apache.spark.sql.catalyst.expressions._
2423
import org.apache.spark.sql.catalyst.plans.logical._
@@ -483,11 +482,7 @@ class Analyzer(
483482
q transformExpressions {
484483
case u @ UnresolvedFunction(name, children) =>
485484
withPosition(u) {
486-
registry.lookupFunction(name, children) match {
487-
case agg2: AggregateFunction2 =>
488-
AggregateExpression2(agg2, aggregate2.Complete, false)
489-
case other => other
490-
}
485+
registry.lookupFunction(name, children)
491486
}
492487
}
493488
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ object FunctionRegistry {
148148

149149
// aggregate functions
150150
expression[Average]("avg"),
151-
expression[aggregate2.Average]("avg2"),
152151
expression[Count]("count"),
153152
expression[First]("first"),
154153
expression[Last]("last"),

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import java.beans.Introspector
2121
import java.util.Properties
2222
import java.util.concurrent.atomic.AtomicReference
2323

24+
import org.apache.spark.sql.execution.aggregate2.ConvertAggregateFunction
25+
2426
import scala.collection.JavaConversions._
2527
import scala.collection.immutable
2628
import scala.language.implicitConversions
@@ -148,6 +150,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
148150
override val extendedResolutionRules =
149151
ExtractPythonUDFs ::
150152
sources.PreInsertCastAndRename ::
153+
ConvertAggregateFunction(self) ::
151154
Nil
152155

153156
override val extendedCheckRules = Seq(

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate2Sort.scala renamed to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.sql.execution
18+
package org.apache.spark.sql.execution.aggregate2
1919

2020
import org.apache.spark.rdd.RDD
2121
import org.apache.spark.sql.catalyst.errors._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate2._
24-
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, AllTuples, UnspecifiedDistribution, Distribution}
24+
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
25+
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
2526
import org.apache.spark.sql.types.NullType
2627

2728
case class Aggregate2Sort(
@@ -71,7 +72,7 @@ case class Aggregate2Sort(
7172
case PartialMerge | Final => func
7273
}
7374
bufferOffset = aggregateExpressions(i).mode match {
74-
case Partial | PartialMerge => bufferOffset + func.bufferValueDataTypes.length
75+
case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
7576
case Final | Complete => bufferOffset + 1
7677
}
7778
i += 1
@@ -88,7 +89,7 @@ case class Aggregate2Sort(
8889
var i = 0
8990
var size = 0
9091
while (i < aggregateFunctions.length) {
91-
size += aggregateFunctions(i).bufferValueDataTypes.length
92+
size += aggregateFunctions(i).bufferSchema.length
9293
i += 1
9394
}
9495
if (preShuffle) {
@@ -132,7 +133,7 @@ case class Aggregate2Sort(
132133

133134
lazy val updateProjection = {
134135
val bufferSchema = aggregateFunctions.flatMap {
135-
case ae: AlgebraicAggregate => ae.bufferSchema
136+
case ae: AlgebraicAggregate => ae.bufferAttributes
136137
}
137138
val updateExpressions = aggregateFunctions.flatMap {
138139
case ae: AlgebraicAggregate => ae.updateExpressions
@@ -145,7 +146,7 @@ case class Aggregate2Sort(
145146
val mergeProjection = {
146147
val bufferSchemata =
147148
offsetAttributes ++ aggregateFunctions.flatMap {
148-
case ae: AlgebraicAggregate => ae.bufferSchema
149+
case ae: AlgebraicAggregate => ae.bufferAttributes
149150
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
150151
case ae: AlgebraicAggregate => ae.rightBufferSchema
151152
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.aggregate2
19+
20+
import org.apache.spark.sql.SQLContext
21+
import org.apache.spark.sql.catalyst.expressions.{Average => Average1}
22+
import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, AggregateExpression2, Complete}
23+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
26+
/**
 * Analysis rule that rewrites legacy aggregate functions into their
 * `aggregate2` counterparts when `context.conf.useSqlAggregate2` is enabled.
 * Each converted function is wrapped in an [[AggregateExpression2]] running
 * in `Complete` mode. Currently only `Average` is converted.
 */
case class ConvertAggregateFunction(context: SQLContext) extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Leave the node untouched until every child is resolved, so that the
    // function's argument expressions have known types before conversion.
    case p: LogicalPlan if !p.childrenResolved => p

    case p if context.conf.useSqlAggregate2 =>
      p.transformExpressionsUp {
        // Swap the old Average for the new implementation. The trailing
        // `false` presumably marks the aggregate as non-DISTINCT — confirm
        // against AggregateExpression2's constructor.
        case Average1(c) => AggregateExpression2(Average2(c), Complete, false)
      }
  }
}

0 commit comments

Comments (0)