project/MimaExcludes.scala (5 additions, 0 deletions)
@@ -328,6 +328,11 @@ object MimaExcludes {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
 
+      // [SPARK-14451][SQL] Move encoder definition into Aggregator interface
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"),
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"),
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"),
+
       ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
       ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")
repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -285,14 +285,16 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
        |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
+       |import org.apache.spark.sql.{Encoder, Encoders}
        |import org.apache.spark.sql.expressions.Aggregator
        |import org.apache.spark.sql.TypedColumn
        |val simpleSum = new Aggregator[Int, Int, Int] {
        |  def zero: Int = 0                     // The initial value.
        |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
        |  def finish(b: Int) = b                // Return the final result.
+       |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+       |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
        |}.toColumn
        |
        |val ds = Seq(1, 2, 3, 4).toDS()
@@ -339,30 +341,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-       |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
-       |import org.apache.spark.sql.expressions.Aggregator
-       |import org.apache.spark.sql.TypedColumn
-       |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-       |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-       |  val numeric = implicitly[Numeric[N]]
-       |  override def zero: N = numeric.zero
-       |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-       |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-       |  override def finish(reduction: N): N = reduction
-       |}
-       |
-       |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-       |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-       |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -267,14 +267,16 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
        |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
+       |import org.apache.spark.sql.{Encoder, Encoders}
        |import org.apache.spark.sql.expressions.Aggregator
        |import org.apache.spark.sql.TypedColumn
        |val simpleSum = new Aggregator[Int, Int, Int] {
        |  def zero: Int = 0                     // The initial value.
        |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
        |  def finish(b: Int) = b                // Return the final result.
+       |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+       |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
        |}.toColumn
        |
        |val ds = Seq(1, 2, 3, 4).toDS()
@@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-       |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
-       |import org.apache.spark.sql.expressions.Aggregator
-       |import org.apache.spark.sql.TypedColumn
-       |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-       |class SumOf[I, N : Numeric](f: I => N) extends
-       |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
-       |  val numeric = implicitly[Numeric[N]]
-       |  override def zero: N = numeric.zero
-       |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-       |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-       |  override def finish(reduction: N): N = reduction
-       |}
-       |
-       |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-       |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-       |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.{Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 
@@ -27,28 +27,20 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
-class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
-  val numeric = implicitly[Numeric[OUT]]
-  override def zero: OUT = numeric.zero
-  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
-  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
-  override def finish(reduction: OUT): OUT = reduction
-
-  // TODO(ekl) java api support once this is exposed in scala
-}
-
-
 class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
   override def finish(reduction: Double): Double = reduction
 
+  override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
    [Review thread on the bufferEncoder line above; a sketch contrasting the two encoder APIs follows this file's diff.]
    Contributor: use Encoders.scalaDouble?
    Contributor (author): this is internal, so it is not that bad to use the internal api.
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
 
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
 
@@ -59,11 +51,14 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -76,11 +71,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -93,10 +90,12 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
     (b1._1 + b2._1, b1._2 + b2._2)
   }
 
+  override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
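On the review exchange above: both spellings yield an Encoder[Double]; the difference is that Encoders.scalaDouble is the public, stable factory while ExpressionEncoder lives in the internal catalyst package, which is why the author considers it acceptable inside Spark's own sources. A minimal standalone sketch of the contrast (not part of the diff):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Public API: the right choice in application code.
val stable: Encoder[Double] = Encoders.scalaDouble

// Internal catalyst API: acceptable in Spark's own sources, as the author
// notes, but not a stable interface for user code.
val internal: Encoder[Double] = ExpressionEncoder[Double]()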
sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
  *
  * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
  *
- * @tparam I The input type for the aggregation.
- * @tparam B The type of the intermediate value of the reduction.
- * @tparam O The type of the final output result.
+ * @tparam IN The input type for the aggregation.
+ * @tparam BUF The type of the intermediate value of the reduction.
+ * @tparam OUT The type of the final output result.
  * @since 1.6.0
  */
-abstract class Aggregator[-I, B, O] extends Serializable {
+abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
 
   /**
    * A zero value for this aggregation. Should satisfy the property that any b + zero = b.
    * @since 1.6.0
    */
-  def zero: B
+  def zero: BUF
 
   /**
    * Combine two values to produce a new value. For performance, the function may modify `b` and
    * return it instead of constructing new object for b.
    * @since 1.6.0
    */
-  def reduce(b: B, a: I): B
+  def reduce(b: BUF, a: IN): BUF
 
   /**
    * Merge two intermediate values.
    * @since 1.6.0
    */
-  def merge(b1: B, b2: B): B
+  def merge(b1: BUF, b2: BUF): BUF
 
   /**
    * Transform the output of the reduction.
    * @since 1.6.0
    */
-  def finish(reduction: B): O
+  def finish(reduction: BUF): OUT
 
   /**
-   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+   * Specifies the [[Encoder]] for the intermediate value type.
+   * @since 2.0.0
+   */
+  def bufferEncoder: Encoder[BUF]
+
+  /**
+   * Specifies the [[Encoder]] for the final output value type.
+   * @since 2.0.0
+   */
+  def outputEncoder: Encoder[OUT]
+
+  /**
+   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]].
    * operations.
    * @since 1.6.0
    */
-  def toColumn(
-      implicit bEncoder: Encoder[B],
-      cEncoder: Encoder[O]): TypedColumn[I, O] = {
+  def toColumn: TypedColumn[IN, OUT] = {
+    implicit val bEncoder = bufferEncoder
+    implicit val cEncoder = outputEncoder
+
     val expr =
       AggregateExpression(
         TypedAggregateExpression(this),
         Complete,
        isDistinct = false)
 
-    new TypedColumn[I, O](expr, encoderFor[O])
+    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
   }
 }
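Taken together, the new contract makes an Aggregator self-describing: it names its own buffer and output encoders, and toColumn takes no implicit parameters. A self-contained sketch mirroring the REPL test above (the session setup and the SimpleSum name are illustrative, not from the PR):

import org.apache.spark.sql.{Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

object SimpleSum extends Aggregator[Int, Int, Int] {
  def zero: Int = 0                                    // the initial value
  def reduce(b: Int, a: Int): Int = b + a              // fold one element into the buffer
  def merge(b1: Int, b2: Int): Int = b1 + b2           // combine partial aggregates
  def finish(b: Int): Int = b                          // produce the final result
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt  // encoder for BUF
  def outputEncoder: Encoder[Int] = Encoders.scalaInt  // encoder for OUT
}

val spark = SparkSession.builder().master("local").appName("aggregator-sketch").getOrCreate()
import spark.implicits._

val sumCol: TypedColumn[Int, Int] = SimpleSum.toColumn        // no implicit encoders needed
val result = Seq(1, 2, 3, 4).toDS().select(sumCol).collect()  // Array(10)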
sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
@@ -26,6 +26,7 @@
 
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
 import org.apache.spark.sql.expressions.Aggregator;
@@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
   public void testTypedAggregationAnonClass() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
 
-    Dataset<Tuple2<String, Integer>> agged =
-      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
 
-    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
-        new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
       .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
     Assert.assertEquals(
       Arrays.asList(
@@ -73,6 +72,16 @@ public Integer merge(Integer b1, Integer b2) {
     public Integer finish(Integer reduction) {
       return reduction;
     }
+
+    @Override
+    public Encoder<Integer> bufferEncoder() {
+      return Encoders.INT();
+    }
+
+    @Override
+    public Encoder<Integer> outputEncoder() {
+      return Encoders.INT();
+    }
   }
 
   @Test