Closed
29 commits
95156fd
Add new UDF subclass UserDefinedAggregator to expose Aggregator for D…
erikerlandson Jun 29, 2019
4a1e01a
update for UserDefinedAggregator
erikerlandson Oct 18, 2019
b0ec357
rename to UDAQuerySuite.scala
erikerlandson Oct 18, 2019
f7ac390
tweak register call to disambiguate the signature
erikerlandson Oct 18, 2019
dbef793
rewrite unit test for UserDefinedAggregator
erikerlandson Oct 18, 2019
3fa704c
remove tdigest
erikerlandson Oct 18, 2019
8e886f6
rename to registerAggregator to avoid overloading conflict
erikerlandson Oct 21, 2019
a85ab11
restore udaf and aggregator imports
erikerlandson Oct 21, 2019
e76c0bc
UserDefinedAggregator is experimental
erikerlandson Oct 21, 2019
fc57259
tdigest deps snuck back in
erikerlandson Oct 21, 2019
651dedd
private[sql] UserDefinedAggregator
erikerlandson Dec 20, 2019
cc48d05
udf.registerAggregator(agg) -> udf.register(functions.udaf(agg))
erikerlandson Dec 21, 2019
6e4d99f
get rid of some vals for reduced serialization
erikerlandson Dec 21, 2019
5ae329a
use UnsafeProjection instead of MutableProjection
erikerlandson Dec 21, 2019
a584a33
use toRow() method in eval
erikerlandson Dec 21, 2019
7b1ddbe
update documentation
erikerlandson Dec 21, 2019
912a87a
resynch UDAQuerySuite to use getErrorMessageInCheckAnswer
erikerlandson Dec 22, 2019
ad56238
remove [[]] for scaladoc
erikerlandson Dec 24, 2019
be55f93
remove unused method
erikerlandson Jan 3, 2020
b51506c
use nodeName method in toString
erikerlandson Jan 3, 2020
0c0e46e
defer resolveAndBind to ScalaAggregator
erikerlandson Jan 3, 2020
ef8392e
simplify serialize() method with toRow()
erikerlandson Jan 5, 2020
b4ad3bc
simplify deserialize with fromRow
erikerlandson Jan 6, 2020
6012f30
code cleanup in ScalaAggregator
erikerlandson Jan 6, 2020
4eaf751
use 'actual' directly in UDAQuerySuite
erikerlandson Jan 6, 2020
986a3b4
invoke super.checkAnswer
erikerlandson Jan 6, 2020
1dead9d
simplify nullable as case class member
erikerlandson Jan 7, 2020
3f0cfec
update 'udaf' scaldoc to reflect the overloading for java
erikerlandson Jan 7, 2020
eb95998
explicitly test type casting failure
erikerlandson Jan 7, 2020
18 changes: 13 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -28,10 +28,11 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

@@ -101,9 +102,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 2.2.0
*/
def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
functionRegistry.createOrReplaceTempFunction(name, builder)
udf
udf match {
case udaf: UserDefinedAggregator[_, _, _] =>
def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
functionRegistry.createOrReplaceTempFunction(name, builder)
udf
case _ =>
def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
functionRegistry.createOrReplaceTempFunction(name, builder)
udf
}
}

// scalastyle:off line.size.limit
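As a usage-level illustration of the dispatch above (not part of the diff; the aggregator, function name, and session setup are hypothetical), a minimal sketch registering an Aggregator-backed UDF by name:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// A hypothetical typed Aggregator summing Long values.
val longSum = new Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val spark = SparkSession.builder().master("local").getOrCreate()
// udaf(longSum) yields a UserDefinedAggregator, so register(...) takes the
// first match arm and installs a ScalaAggregator builder for the name.
spark.udf.register("long_sum", udaf(longSum))
spark.sql("SELECT long_sum(id) FROM range(10)").show()  // 45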
@@ -17,13 +17,17 @@

package org.apache.spark.sql.execution.aggregate

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
@@ -450,3 +454,63 @@ case class ScalaUDAF(

override def nodeName: String = udaf.getClass.getSimpleName
}

case class ScalaAggregator[IN, BUF, OUT](
children: Seq[Expression],
agg: Aggregator[IN, BUF, OUT],
inputEncoderNR: ExpressionEncoder[IN],
nullable: Boolean = true,
isDeterministic: Boolean = true,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[BUF]
with NonSQLExpression
with UserDefinedExpression
with ImplicitCastInputTypes
with Logging {

// Resolution happens inside this lazy val, deferring it; see the review discussion immediately below
private[this] lazy val inputEncoder = inputEncoderNR.resolveAndBind()
Review comment (Contributor):
I think this is the problem. We shouldn't keep the encoder unresolved in the query plan and then resolve it on the executor side. We can follow ResolveEncodersInUDF: add a rule to resolve the encoders in ScalaAggregator on the driver side.

cc @viirya @dongjoon-hyun

Review comment (Member):
Yea, this defers resolving the encoder to the executors; we should resolve it on the driver.

Review comment (Member):
Thank you for pinging me, @cloud-fan.

(A sketch of such a driver-side resolution rule appears after the ScalaAggregator definition below.)

private[this] lazy val bufferEncoder =
agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]

def dataType: DataType = outputEncoder.objSerializer.dataType

def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)

override lazy val deterministic: Boolean = isDeterministic

def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
copy(inputAggBufferOffset = newInputAggBufferOffset)

private[this] lazy val inputProjection = UnsafeProjection.create(children)

def createAggregationBuffer(): BUF = agg.zero

def update(buffer: BUF, input: InternalRow): BUF =
agg.reduce(buffer, inputEncoder.fromRow(inputProjection(input)))

def merge(buffer: BUF, input: BUF): BUF = agg.merge(buffer, input)

def eval(buffer: BUF): Any = {
val row = outputEncoder.toRow(agg.finish(buffer))
// Struct-typed outputs are returned as a whole row; other outputs are unwrapped from the single field
if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
}

// Reusable row for deserializing buffer bytes; its width matches the buffer schema
private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)

def serialize(agg: BUF): Array[Byte] =
bufferEncoder.toRow(agg).asInstanceOf[UnsafeRow].getBytes()

def deserialize(storageFormat: Array[Byte]): BUF = {
bufferRow.pointTo(storageFormat, storageFormat.length)
bufferEncoder.fromRow(bufferRow)
}

override def toString: String = s"""${nodeName}(${children.mkString(",")})"""

override def nodeName: String = agg.getClass.getSimpleName
}
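Following the review discussion above, here is a hedged sketch of what a driver-side encoder resolution rule could look like, modeled on the existing ResolveEncodersInUDF rule. The rule name and its wiring into the analyzer are assumptions for illustration, not part of this diff:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.ScalaAggregator

// Hypothetical analyzer rule: resolve ScalaAggregator's input encoder on the
// driver, rather than deferring resolveAndBind() to the executors.
object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case p if !p.resolved => p
    case p => p.transformExpressionsUp {
      case agg: ScalaAggregator[_, _, _] =>
        agg.copy(inputEncoderNR = agg.inputEncoderNR.resolveAndBind())
    }
  }
}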
@@ -17,10 +17,15 @@

package org.apache.spark.sql.expressions

import org.apache.spark.annotation.Stable
import org.apache.spark.sql.Column
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.annotation.{Experimental, Stable}
import org.apache.spark.sql.{Column, Encoder}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.ScalaAggregator
import org.apache.spark.sql.types.{AnyDataType, DataType}

/**
@@ -136,3 +141,42 @@ private[sql] case class SparkUserDefinedFunction(
}
}
}

private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
aggregator: Aggregator[IN, BUF, OUT],
inputEncoder: Encoder[IN],
name: Option[String] = None,
nullable: Boolean = true,
deterministic: Boolean = true) extends UserDefinedFunction {

@scala.annotation.varargs
def apply(exprs: Column*): Column = {
Column(AggregateExpression(scalaAggregator(exprs.map(_.expr)), Complete, isDistinct = false))
}

// This is also used by udf.register(...) when it detects a UserDefinedAggregator
def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
}

override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
copy(name = Option(name))
}

override def asNonNullable(): UserDefinedAggregator[IN, BUF, OUT] = {
if (!nullable) {
this
} else {
copy(nullable = false)
}
}

override def asNondeterministic(): UserDefinedAggregator[IN, BUF, OUT] = {
if (!deterministic) {
this
} else {
copy(deterministic = false)
}
}
}
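The copy-based modifiers above compose naturally; a small illustrative sketch (`agg` stands for any Aggregator, as in the udaf examples elsewhere in this PR):

// Hypothetical usage: udaf(agg) produces a UserDefinedAggregator that is
// nullable and deterministic by default; each modifier returns a copy.
val f = functions.udaf(agg)
val g = f.withName("myAgg").asNonNullable().asNondeterministic()
// The nullable/deterministic flags are threaded into ScalaAggregator
// via scalaAggregator(exprs) when the function is applied or registered.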
64 changes: 62 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -32,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction}
import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
* Commonly used functions available for DataFrame operations. Using functions defined here provides
* a little bit more compile-time safety to make sure the function exists.
@@ -4231,6 +4230,67 @@
// Scala UDF functions
//////////////////////////////////////////////////////////////////////////////////////////////

/**
* Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
* so that it may be used with untyped DataFrames.
* {{{
* val agg = // Aggregator[IN, BUF, OUT]
*
* // declare a UDF based on agg
* val aggUDF = udaf(agg)
* val aggData = df.agg(aggUDF($"colname"))
*
* // register agg as a named function
* spark.udf.register("myAggName", udaf(agg))
* }}}
*
* @tparam IN the aggregator input type
* @tparam BUF the aggregating buffer type
* @tparam OUT the finalized output type
*
* @param agg the typed Aggregator
*
* @return a UserDefinedFunction that can be used as an aggregating expression.
*
* @note The input encoder is inferred from the input type IN.
*/
def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = {
udaf(agg, ExpressionEncoder[IN]())
}

/**
* Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
* so that it may be used with untyped DataFrames.
* {{{
* Aggregator<IN, BUF, OUT> agg = // custom Aggregator
* Encoder<IN> enc = // input encoder
*
* // declare a UDF based on agg
* UserDefinedFunction aggUDF = udaf(agg, enc);
* Dataset<Row> aggData = df.agg(aggUDF.apply(col("colname")));
*
* // register agg as a named function
* spark.udf.register("myAggName", udaf(agg, enc));
* }}}
*
* @tparam IN the aggregator input type
* @tparam BUF the aggregating buffer type
* @tparam OUT the finalized output type
*
* @param agg the typed Aggregator
* @param inputEncoder a specific input encoder to use
*
* @return a UserDefinedFunction that can be used as an aggregating expression
*
* @note This overload takes an explicit input encoder, to support UDAF
* declarations in Java.
*/
def udaf[IN, BUF, OUT](
agg: Aggregator[IN, BUF, OUT],
inputEncoder: Encoder[IN]): UserDefinedFunction = {
UserDefinedAggregator(agg, inputEncoder)
}

/**
* Defines a Scala closure of 0 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the Scala closure's
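To make the scaladoc placeholders above concrete, a minimal self-contained sketch; the mean aggregator and its names are hypothetical, not part of this diff:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// A hypothetical Aggregator[IN = Double, BUF = (Double, Long), OUT = Double]
// computing an arithmetic mean.
val mean = new Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)
  def reduce(b: (Double, Long), a: Double): (Double, Long) = (b._1 + a, b._2 + 1)
  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  def finish(r: (Double, Long)): Double = if (r._2 == 0) 0.0 else r._1 / r._2
  def bufferEncoder: Encoder[(Double, Long)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Scala callers: the input encoder is inferred from the Double TypeTag.
val meanUdf = udaf(mean)
// Java-oriented overload: the input encoder is passed explicitly.
val meanUdfExplicit = udaf(mean, Encoders.scalaDouble)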
@@ -22,7 +22,7 @@ import org.scalatest.Matchers.the
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -412,6 +412,42 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSparkSession {
Row("b", 2, 4, 8)))
}

test("window function with aggregator") {
val agg = udaf(new Aggregator[(Long, Long), Long, Long] {
def zero: Long = 0L
def reduce(b: Long, a: (Long, Long)): Long = b + (a._1 * a._2)
def merge(b1: Long, b2: Long): Long = b1 + b2
def finish(r: Long): Long = r
def bufferEncoder: Encoder[Long] = Encoders.scalaLong
def outputEncoder: Encoder[Long] = Encoders.scalaLong
})
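// The (Long, Long) input type means this aggregator consumes two columns;
// below, agg($"a", $"b") binds columns "a" and "b" to the tuple fields in order.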

val df = Seq(
("a", 1, 1),
("a", 1, 5),
("a", 2, 10),
("a", 2, -1),
("b", 4, 7),
("b", 3, 8),
("b", 2, 4))
.toDF("key", "a", "b")
val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L)
checkAnswer(
df.select(
$"key",
$"a",
$"b",
agg($"a", $"b").over(window)),
Seq(
Row("a", 1, 1, 6),
Row("a", 1, 5, 6),
Row("a", 2, 10, 24),
Row("a", 2, -1, 24),
Row("b", 4, 7, 60),
Row("b", 3, 8, 32),
Row("b", 2, 4, 8)))
}

test("null inputs") {
val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2))
.toDF("key", "value")