Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,19 @@ import org.apache.spark.util.Utils
* to the name `value`.
*/
object ExpressionEncoder {
def apply[T : TypeTag](): ExpressionEncoder[T] = {
// Constructs an encoder for top-level row.
def apply[T : TypeTag](): ExpressionEncoder[T] = apply(topLevel = true)

/**
* @param topLevel whether the encoders to construct are for top-level row.
*/
def apply[T : TypeTag](topLevel: Boolean): ExpressionEncoder[T] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we call this apply with topLevel = false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Aggregator, we can call this apply with topLevel = false to avoid producing a nested struct.

// We convert the not-serializable TypeTag into StructType and ClassTag.
val mirror = ScalaReflection.mirror
val tpe = typeTag[T].in(mirror).tpe

if (ScalaReflection.optionOfProductType(tpe)) {
// For non top-level encoders, we allow using Option of Product type.
if (topLevel && ScalaReflection.optionOfProductType(tpe)) {
throw new UnsupportedOperationException(
"Cannot create encoder for Option of Product type, because Product type is represented " +
"as a row, and the entire row can not be null in Spark SQL like normal databases. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] {


// Record with a name and an optional boolean column.
case class OptionBooleanData(name: String, isGood: Option[Boolean])
// Record whose optional column holds a (Boolean, Int) tuple, i.e. an Option of a Product type.
case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)])

case class OptionBooleanAggregator(colName: String)
extends Aggregator[Row, Option[Boolean], Option[Boolean]] {
Expand Down Expand Up @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String)
def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder()
}

case class OptionBooleanIntAggregator(colName: String)
extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the expected schema after we apply an aggregator with Option[Product] as buffer/output?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a non top-level encoder, the output schema of Option[Product] should be a struct column.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assuming non top level, Option[Product] is same as Product?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. For non top level, Option[Product] is the same as Product. The difference is the additional WrapOption and UnwrapOption around expressions.


override def zero: Option[(Boolean, Int)] = None

/**
 * Folds a single input row into the running buffer.
 *
 * Looks up the column named `colName` in `row`. A null column contributes `None`;
 * otherwise the column is read as a struct whose field 0 is a Boolean and field 1 an Int.
 * The resulting value is combined with `buffer` through `merge`.
 *
 * @param buffer the aggregation state accumulated so far
 * @param row    the incoming row to fold in
 * @return the buffer merged with the value extracted from `row`
 */
override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = {
  val pos = row.fieldIndex(colName)
  val incoming: Option[(Boolean, Int)] =
    if (row.isNullAt(pos)) {
      None
    } else {
      // Unpack the nested struct positionally into a (Boolean, Int) pair.
      val struct = row.getStruct(pos)
      Some((struct.getBoolean(0), struct.getInt(1)))
    }
  merge(buffer, incoming)
}

/**
 * Combines two partial buffers.
 *
 * If either buffer is present with its Boolean component set to true, the result is
 * `Some((true, sum))`, where `sum` adds the Int components of both buffers and an
 * absent buffer counts as 0. Otherwise the result is `b1` when it is defined, else `b2`.
 *
 * @param b1 the first partial buffer
 * @param b2 the second partial buffer
 * @return the merged buffer
 */
override def merge(
    b1: Option[(Boolean, Int)],
    b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = {
  // True when at least one side is present and carries a true flag.
  val anyFlagged = b1.exists(_._1) || b2.exists(_._1)
  if (anyFlagged) {
    // Both Int components are summed, including the one from an unflagged side.
    val total = b1.fold(0)(_._2) + b2.fold(0)(_._2)
    Some((true, total))
  } else {
    // No flagged side: keep b1 when present, otherwise fall back to b2 (possibly None).
    b1.orElse(b2)
  }
}

// The final result is the reduction buffer itself, returned unchanged.
override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction

// Buffer and output have the same type, so both delegate to the same encoder factory.
override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder

// Requests a non top-level encoder (topLevel = false) for the Option-of-tuple type.
// Declared as a def, so each access builds a new encoder.
def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder(topLevel = false)
}

class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._

Expand Down Expand Up @@ -393,4 +431,17 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
assert(grouped.schema == df.schema)
checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true)))
}

test("SPARK-24762: Aggregator should be able to use Option of Product encoder") {
val df = Seq(
OptionBooleanIntData("bob", Some((true, 1))),
OptionBooleanIntData("bob", Some((false, 2))),
OptionBooleanIntData("bob", None)).toDF()
val group = df
.groupBy("name")
.agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
assert(df.schema == group.schema)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's write down the expected schema

val expectedSchema = ...
assert(df.schema == expectedSchema)
assert(grouped.schema == ...)

checkAnswer(group, Row("bob", Row(true, 3)) :: Nil)
checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
}
}