|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.execution.aggregate |
19 | 19 |
|
| 20 | +import scala.reflect.runtime.universe.TypeTag |
| 21 | + |
20 | 22 | import org.apache.spark.api.java.function.MapFunction |
21 | 23 | import org.apache.spark.sql.{Encoder, TypedColumn} |
22 | 24 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder |
23 | 25 | import org.apache.spark.sql.expressions.Aggregator |
24 | 26 |
|
| 27 | + |
25 | 28 | //////////////////////////////////////////////////////////////////////////////////////////////////// |
26 | 29 | // This file defines internal implementations for aggregators. |
27 | 30 | //////////////////////////////////////////////////////////////////////////////////////////////////// |
28 | 31 |
|
| 32 | +abstract class TypedAggregator[IN, BUF: TypeTag, OUT: TypeTag, FRESULT, JAVA](f: IN => FRESULT) |
| 33 | + extends Aggregator[IN, BUF, OUT] { |
| 34 | + |
| 35 | + def bufferEncoder: Encoder[BUF] = ExpressionEncoder[BUF]() |
| 36 | + def outputEncoder: Encoder[OUT] = ExpressionEncoder[OUT]() |
| 37 | + |
| 38 | + def toColumnJava: TypedColumn[IN, JAVA] = { |
| 39 | + toColumn.asInstanceOf[TypedColumn[IN, JAVA]] |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +class TypedSumDouble[IN](val f: IN => Double) |
| 44 | + extends TypedAggregator[IN, Double, Double, Double, java.lang.Double](f) { |
29 | 45 |
|
30 | | -class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] { |
31 | 46 | override def zero: Double = 0.0 |
32 | 47 | override def reduce(b: Double, a: IN): Double = b + f(a) |
33 | 48 | override def merge(b1: Double, b2: Double): Double = b1 + b2 |
34 | 49 | override def finish(reduction: Double): Double = reduction |
35 | 50 |
|
36 | | - override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]() |
37 | | - override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() |
38 | | - |
39 | | - // Java api support |
| 51 | + // Java constructor |
40 | 52 | def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) |
41 | | - |
42 | | - def toColumnJava: TypedColumn[IN, java.lang.Double] = { |
43 | | - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] |
44 | | - } |
45 | 53 | } |
46 | 54 |
|
47 | 55 |
|
48 | | -class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] { |
| 56 | +class TypedSumLong[IN](val f: IN => Long) |
| 57 | + extends TypedAggregator[IN, Long, Long, Long, java.lang.Long](f) { |
| 58 | + |
49 | 59 | override def zero: Long = 0L |
50 | 60 | override def reduce(b: Long, a: IN): Long = b + f(a) |
51 | 61 | override def merge(b1: Long, b2: Long): Long = b1 + b2 |
52 | 62 | override def finish(reduction: Long): Long = reduction |
53 | 63 |
|
54 | | - override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() |
55 | | - override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() |
56 | | - |
57 | | - // Java api support |
| 64 | + // Java constructor |
58 | 65 | def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) |
59 | | - |
60 | | - def toColumnJava: TypedColumn[IN, java.lang.Long] = { |
61 | | - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] |
62 | | - } |
63 | 66 | } |
64 | 67 |
|
| 68 | +class TypedCount[IN](val f: IN => Any) |
| 69 | + extends TypedAggregator[IN, Long, Long, Any, java.lang.Long](f) { |
65 | 70 |
|
66 | | -class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] { |
67 | 71 | override def zero: Long = 0 |
68 | | - override def reduce(b: Long, a: IN): Long = { |
69 | | - if (f(a) == null) b else b + 1 |
70 | | - } |
| 72 | + override def reduce(b: Long, a: IN): Long = if (f(a) == null) b else b + 1 |
71 | 73 | override def merge(b1: Long, b2: Long): Long = b1 + b2 |
72 | 74 | override def finish(reduction: Long): Long = reduction |
73 | 75 |
|
74 | | - override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() |
75 | | - override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() |
76 | | - |
77 | | - // Java api support |
| 76 | + // Java constructor |
78 | 77 | def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) |
79 | | - def toColumnJava: TypedColumn[IN, java.lang.Long] = { |
80 | | - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] |
81 | | - } |
| 78 | + |
82 | 79 | } |
83 | 80 |
|
| 81 | +class TypedAverage[IN](val f: IN => Double) |
| 82 | + extends TypedAggregator[IN, (Double, Long), Double, Double, java.lang.Double](f) { |
84 | 83 |
|
85 | | -class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] { |
86 | 84 | override def zero: (Double, Long) = (0.0, 0L) |
87 | 85 | override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2) |
88 | 86 | override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2 |
89 | 87 | override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { |
90 | 88 | (b1._1 + b2._1, b1._2 + b2._2) |
91 | 89 | } |
92 | 90 |
|
93 | | - override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]() |
94 | | - override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() |
95 | | - |
96 | 91 | // Java api support |
97 | 92 | def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) |
98 | | - def toColumnJava: TypedColumn[IN, java.lang.Double] = { |
99 | | - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] |
100 | | - } |
101 | 93 | } |
0 commit comments