Commit 1a47e10

Renamed apply to eval for generators and added a bunch of override's.
1 parent: ea061de

6 files changed: +50 −52 lines
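
In short: before this commit a Generator was both an Expression and a Row => TraversableOnce[Row] function, producing its rows through apply; afterwards it produces them through eval, the evaluation hook that every Expression already declares. The plain-Scala sketch below mirrors only that shape; RowLike, GeneratorBefore, ExpressionLike, and GeneratorAfter are illustrative stand-ins, not Catalyst classes.

// Plain-Scala sketch of the rename (illustrative stand-ins, not Catalyst code).
trait RowLike

// Before: the row-producing hook was apply, inherited from the function trait.
abstract class GeneratorBefore extends (RowLike => TraversableOnce[RowLike]) {
  def apply(input: RowLike): TraversableOnce[RowLike]
}

// After: the hook is eval, the same method every Expression-like node exposes,
// with the evaluated type narrowed to a collection of rows.
abstract class ExpressionLike {
  type EvaluatedType <: Any
  def eval(input: RowLike): EvaluatedType
}

abstract class GeneratorAfter extends ExpressionLike {
  override type EvaluatedType = TraversableOnce[RowLike]
  override def eval(input: RowLike): TraversableOnce[RowLike]
}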

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 3 deletions
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
 
@@ -231,7 +231,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
 
   override def foldable = left.foldable && right.foldable
 
-  def references = left.references ++ right.references
+  override def references = left.references ++ right.references
 
   override def toString = s"($left $symbol $right)"
 }
@@ -243,5 +243,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
 abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
   self: Product =>
 
-  def references = child.references
+  override def references = child.references
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 33 additions & 35 deletions
@@ -43,7 +43,7 @@ case class SplitEvaluation(
     partialEvaluations: Seq[NamedExpression])
 
 /**
- * An [[AggregateExpression]] that can be partially computed without seeing all relevent tuples.
+ * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
  * These partial evaluations can then be combined to compute the actual answer.
  */
 abstract class PartialAggregate extends AggregateExpression {
@@ -63,28 +63,28 @@ abstract class AggregateFunction
   extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
   self: Product =>
 
-  type EvaluatedType = Any
+  override type EvaluatedType = Any
 
   /** Base should return the generic aggregate expression that this function is computing */
   val base: AggregateExpression
-  def references = base.references
-  def nullable = base.nullable
-  def dataType = base.dataType
+  override def references = base.references
+  override def nullable = base.nullable
+  override def dataType = base.dataType
 
   def update(input: Row): Unit
   override def eval(input: Row): Any
 
   // Do we really need this?
-  def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+  override def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
 }
 
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = IntegerType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
   override def toString = s"COUNT($child)"
 
-  def asPartial: SplitEvaluation = {
+  override def asPartial: SplitEvaluation = {
     val partialCount = Alias(Count(child), "PartialCount")()
     SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
   }
@@ -93,18 +93,18 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
 }
 
 case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
-  def children = expressions
-  def references = expressions.flatMap(_.references).toSet
-  def nullable = false
-  def dataType = IntegerType
+  override def children = expressions
+  override def references = expressions.flatMap(_.references).toSet
+  override def nullable = false
+  override def dataType = IntegerType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
   override def newInstance = new CountDistinctFunction(expressions, this)
 }
 
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = DoubleType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = DoubleType
   override def toString = s"AVG($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -122,9 +122,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
 }
 
 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = child.dataType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
   override def toString = s"SUM($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -140,18 +140,18 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
-  def references = child.references
-  def nullable = false
-  def dataType = child.dataType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"
 
   override def newInstance = new SumDistinctFunction(child, this)
 }
 
 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = child.nullable
-  def dataType = child.dataType
+  override def references = child.references
+  override def nullable = child.nullable
+  override def dataType = child.dataType
   override def toString = s"FIRST($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -172,14 +172,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
   private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
   private val sumAsDouble = Cast(sum, DoubleType)
 
-
-
   private val addFunction = Add(sum, expr)
 
   override def eval(input: Row): Any =
     sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
 
-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     count += 1
     sum.update(addFunction, input)
   }
@@ -190,7 +188,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
 
   var count: Int = _
 
-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     val evaluatedExpr = expr.map(_.eval(input))
     if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
       count += 1
@@ -207,7 +205,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
 
   private val addFunction = Add(sum, expr)
 
-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     sum.update(addFunction, input)
   }
 
@@ -219,9 +217,9 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
 
   def this() = this(null, null) // Required for serialization.
 
-  val seen = new scala.collection.mutable.HashSet[Any]()
+  private val seen = new scala.collection.mutable.HashSet[Any]()
 
-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     val evaluatedExpr = expr.eval(input)
     if (evaluatedExpr != null) {
       seen += evaluatedExpr
@@ -239,7 +237,7 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio
 
   val seen = new scala.collection.mutable.HashSet[Any]()
 
-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     val evaluatedExpr = expr.map(_.eval(input))
     if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
       seen += evaluatedExpr
@@ -254,7 +252,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
 
   var result: Any = null
 
-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     if (result == null) {
       result = expr.eval(input)
     }
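
The update methods above gain override because AggregateFunction declares update(input: Row): Unit abstractly: each concrete function is fed one input row at a time through update and reports its accumulated result through eval. A minimal, self-contained sketch of that protocol (plain Scala; SketchCountFunction and the Seq[Any] rows are illustrative stand-ins, not the Catalyst classes):

// Illustrative accumulate-then-report protocol, loosely modelled on
// CountFunction above: update() sees one row at a time, eval() returns the
// running result.
class SketchCountFunction {
  private var count: Int = 0

  // Count the row if any of its values is non-null, roughly mirroring the
  // non-null check in CountFunction.update.
  def update(row: Seq[Any]): Unit =
    if (row.exists(_ != null)) count += 1

  def eval(): Any = count
}

// Usage: two of the three rows contain a non-null value.
// val f = new SketchCountFunction
// Seq(Seq[Any]("a"), Seq[Any](null), Seq[Any]("b")).foreach(f.update)
// f.eval() == 2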

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 8 additions & 8 deletions
@@ -35,17 +35,17 @@ import org.apache.spark.sql.catalyst.types._
  * requested. The attributes produced by this function will be automatically copied anytime rules
  * result in changes to the Generator or its children.
  */
-abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
+abstract class Generator extends Expression {
   self: Product =>
 
-  type EvaluatedType = TraversableOnce[Row]
+  override type EvaluatedType = TraversableOnce[Row]
 
-  lazy val dataType =
+  override lazy val dataType =
     ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))
 
-  def nullable = false
+  override def nullable = false
 
-  def references = children.flatMap(_.references).toSet
+  override def references = children.flatMap(_.references).toSet
 
   /**
    * Should be overridden by specific generators. Called only once for each instance to ensure
@@ -63,7 +63,7 @@ abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
   }
 
   /** Should be implemented by child classes to perform specific Generators. */
-  def apply(input: Row): TraversableOnce[Row]
+  override def eval(input: Row): TraversableOnce[Row]
 
   /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
   override def makeCopy(newArgs: Array[AnyRef]): this.type = {
@@ -83,7 +83,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
     child.resolved &&
     (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
 
-  lazy val elementTypes = child.dataType match {
+  private lazy val elementTypes = child.dataType match {
     case ArrayType(et) => et :: Nil
     case MapType(kt,vt) => kt :: vt :: Nil
   }
@@ -100,7 +100,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
     }
   }
 
-  override def apply(input: Row): TraversableOnce[Row] = {
+  override def eval(input: Row): TraversableOnce[Row] = {
     child.dataType match {
       case ArrayType(_) =>
         val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
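
For context, Explode.eval (renamed from apply above) turns one array- or map-valued input into zero or more output rows. A self-contained, plain-Scala sketch of that behaviour (explodeSketch and its Seq[Any] rows are illustrative stand-ins; the real implementation works on Catalyst rows and data types):

// Illustrative only: one array element (or one map entry) per output row.
// Inputs are assumed to be array-valued, map-valued, or null, as in Explode.
object ExplodeSketch {
  def explodeSketch(value: Any): TraversableOnce[Seq[Any]] = value match {
    case null           => Nil                                       // null input yields no rows
    case arr: Seq[_]    => arr.map(v => Seq[Any](v))                 // one row per element
    case map: Map[_, _] => map.map { case (k, v) => Seq[Any](k, v) } // one row per entry
  }

  // explodeSketch(Seq(1, 2, 3))            => rows (1), (2), (3)
  // explodeSketch(Map("a" -> 1, "b" -> 2)) => rows (a, 1), (b, 2)
}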

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 4 additions & 4 deletions
@@ -36,10 +36,10 @@ case class Generate(
     child: SparkPlan)
   extends UnaryNode {
 
-  def output =
+  override def output =
     if (join) child.output ++ generator.output else generator.output
 
-  def execute() = {
+  override def execute() = {
     if (join) {
       child.execute().mapPartitions { iter =>
         val nullValues = Seq.fill(generator.output.size)(Literal(null))
@@ -52,7 +52,7 @@ case class Generate(
         val joinedRow = new JoinedRow
 
         iter.flatMap {row =>
-          val outputRows = generator(row)
+          val outputRows = generator.eval(row)
           if (outer && outputRows.isEmpty) {
             outerProjection(row) :: Nil
           } else {
@@ -61,7 +61,7 @@ case class Generate(
         }
       }
     } else {
-      child.execute().mapPartitions(iter => iter.flatMap(generator))
+      child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
     }
   }
 }
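
Why both call sites above change: since Generator no longer mixes in Row => TraversableOnce[Row], it cannot be passed directly where a function is expected, so iter.flatMap(generator) becomes the explicit iter.flatMap(row => generator.eval(row)). A self-contained, plain-Scala sketch of the difference (FnGenerator, EvalGenerator, and the Seq[Any] rows are illustrative stand-ins, not the Catalyst classes):

// Illustrative only: a Function1 subclass can be handed to flatMap as-is,
// while a class that merely has an eval method needs an explicit lambda.
object ApplyVsEval {
  type Row = Seq[Any]

  class FnGenerator extends (Row => TraversableOnce[Row]) {
    def apply(input: Row): TraversableOnce[Row] = Iterator(input, input)
  }

  class EvalGenerator {
    def eval(input: Row): TraversableOnce[Row] = Iterator(input, input)
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(Seq[Any](1), Seq[Any](2))
    val before = rows.iterator.flatMap(new FnGenerator).toList     // generator used as the function itself
    val gen = new EvalGenerator
    val after = rows.iterator.flatMap(row => gen.eval(row)).toList // explicit lambda around eval
    assert(before == after)
  }
}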

sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generato
 
   val Seq(nameAttr, ageAttr) = input
 
-  override def apply(input: Row): TraversableOnce[Row] = {
+  override def eval(input: Row): TraversableOnce[Row] = {
     val name = nameAttr.eval(input)
     val age = ageAttr.eval(input).asInstanceOf[Int]
 

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 1 addition & 1 deletion
@@ -403,7 +403,7 @@ case class HiveGenericUdtf(
     }
   }
 
-  override def apply(input: Row): TraversableOnce[Row] = {
+  override def eval(input: Row): TraversableOnce[Row] = {
     outputInspectors // Make sure initialized.
 
     val inputProjection = new Projection(children)
