Skip to content

Commit 2e6899f

Browse files
committed
reduced duplicate code
1 parent 4816c2e commit 2e6899f

File tree

1 file changed

+28
-36
lines changed

1 file changed

+28
-36
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,85 +17,77 @@
1717

1818
package org.apache.spark.sql.execution.aggregate
1919

20+
import scala.reflect.runtime.universe.TypeTag
21+
2022
import org.apache.spark.api.java.function.MapFunction
2123
import org.apache.spark.sql.{Encoder, TypedColumn}
2224
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2325
import org.apache.spark.sql.expressions.Aggregator
2426

27+
2528
////////////////////////////////////////////////////////////////////////////////////////////////////
2629
// This file defines internal implementations for aggregators.
2730
////////////////////////////////////////////////////////////////////////////////////////////////////
2831

32+
abstract class TypedAggregator[IN, BUF: TypeTag, OUT: TypeTag, FRESULT, JAVA](f: IN => FRESULT)
33+
extends Aggregator[IN, BUF, OUT] {
34+
35+
def bufferEncoder: Encoder[BUF] = ExpressionEncoder[BUF]()
36+
def outputEncoder: Encoder[OUT] = ExpressionEncoder[OUT]()
37+
38+
def toColumnJava: TypedColumn[IN, JAVA] = {
39+
toColumn.asInstanceOf[TypedColumn[IN, JAVA]]
40+
}
41+
}
42+
43+
class TypedSumDouble[IN](val f: IN => Double)
44+
extends TypedAggregator[IN, Double, Double, Double, java.lang.Double](f) {
2945

30-
class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] {
3146
override def zero: Double = 0.0
3247
override def reduce(b: Double, a: IN): Double = b + f(a)
3348
override def merge(b1: Double, b2: Double): Double = b1 + b2
3449
override def finish(reduction: Double): Double = reduction
3550

36-
override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
37-
override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
38-
39-
// Java api support
51+
// Java constructor
4052
def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
41-
42-
def toColumnJava: TypedColumn[IN, java.lang.Double] = {
43-
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
44-
}
4553
}
4654

4755

48-
class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
56+
class TypedSumLong[IN](val f: IN => Long)
57+
extends TypedAggregator[IN, Long, Long, Long, java.lang.Long](f) {
58+
4959
override def zero: Long = 0L
5060
override def reduce(b: Long, a: IN): Long = b + f(a)
5161
override def merge(b1: Long, b2: Long): Long = b1 + b2
5262
override def finish(reduction: Long): Long = reduction
5363

54-
override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
55-
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
56-
57-
// Java api support
64+
// Java constructor
5865
def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
59-
60-
def toColumnJava: TypedColumn[IN, java.lang.Long] = {
61-
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
62-
}
6366
}
6467

68+
class TypedCount[IN](val f: IN => Any)
69+
extends TypedAggregator[IN, Long, Long, Any, java.lang.Long](f) {
6570

66-
class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
6771
override def zero: Long = 0
68-
override def reduce(b: Long, a: IN): Long = {
69-
if (f(a) == null) b else b + 1
70-
}
72+
override def reduce(b: Long, a: IN): Long = if (f(a) == null) b else b + 1
7173
override def merge(b1: Long, b2: Long): Long = b1 + b2
7274
override def finish(reduction: Long): Long = reduction
7375

74-
override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
75-
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
76-
77-
// Java api support
76+
// Java constructor
7877
def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
79-
def toColumnJava: TypedColumn[IN, java.lang.Long] = {
80-
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
81-
}
78+
8279
}
8380

81+
class TypedAverage[IN](val f: IN => Double)
82+
extends TypedAggregator[IN, (Double, Long), Double, Double, java.lang.Double](f) {
8483

85-
class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
8684
override def zero: (Double, Long) = (0.0, 0L)
8785
override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
8886
override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
8987
override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
9088
(b1._1 + b2._1, b1._2 + b2._2)
9189
}
9290

93-
override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
94-
override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
95-
9691
// Java api support
9792
def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
98-
def toColumnJava: TypedColumn[IN, java.lang.Double] = {
99-
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
100-
}
10193
}

0 commit comments

Comments
 (0)