diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala index f8af91980bde..5510f0019353 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala @@ -44,6 +44,12 @@ object SimpleTypedAggregator { println("running typed average:") ds.groupByKey(_._1).agg(new TypedAverage[(Long, Long)](_._2.toDouble).toColumn).show() + println("running typed minimum:") + ds.groupByKey(_._1).agg(new TypedMin[(Long, Long)](_._2.toDouble).toColumn).show() + + println("running typed maximum:") + ds.groupByKey(_._1).agg(new TypedMax[(Long, Long)](_._2).toColumn).show() + spark.stop() } } @@ -84,3 +90,71 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long } override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } + +class TypedMin[IN](val f: IN => Double) extends Aggregator[IN, MutableDouble, Option[Double]] { + override def zero: MutableDouble = null + override def reduce(b: MutableDouble, a: IN): MutableDouble = { + if (b == null) { + new MutableDouble(f(a)) + } else { + b.value = math.min(b.value, f(a)) + b + } + } + override def merge(b1: MutableDouble, b2: MutableDouble): MutableDouble = { + if (b1 == null) { + b2 + } else if (b2 == null) { + b1 + } else { + b1.value = math.min(b1.value, b2.value) + b1 + } + } + override def finish(reduction: MutableDouble): Option[Double] = { + if (reduction != null) { + Some(reduction.value) + } else { + None + } + } + + override def bufferEncoder: Encoder[MutableDouble] = Encoders.kryo[MutableDouble] + override def outputEncoder: Encoder[Option[Double]] = Encoders.product[Option[Double]] +} + +class TypedMax[IN](val f: IN => Long) extends Aggregator[IN, MutableLong, Option[Long]] { + override def zero: MutableLong = null + override def reduce(b: MutableLong, a: IN): MutableLong = { + if (b == null) { + new MutableLong(f(a)) + } else { + b.value = math.max(b.value, f(a)) + b + } + } + override def merge(b1: MutableLong, b2: MutableLong): MutableLong = { + if (b1 == null) { + b2 + } else if (b2 == null) { + b1 + } else { + b1.value = math.max(b1.value, b2.value) + b1 + } + } + override def finish(reduction: MutableLong): Option[Long] = { + if (reduction != null) { + Some(reduction.value) + } else { + None + } + } + + override def bufferEncoder: Encoder[MutableLong] = Encoders.kryo[MutableLong] + override def outputEncoder: Encoder[Option[Long]] = Encoders.product[Option[Long]] +} + +class MutableLong(var value: Long) extends Serializable + +class MutableDouble(var value: Double) extends Serializable