diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a6867a67eead..7e600286e955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.ReduceAggregator /** * :: Experimental :: @@ -177,10 +178,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) + val encoder = encoderFor[V] + val aggregator: TypedColumn[V, V] = new ReduceAggregator(f, encoder).toColumn - implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc) - flatMapGroups(func) + agg(aggregator) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala new file mode 100644 index 000000000000..0f013d49a2ca --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +/** + * :: Experimental :: + * A generic class for reduce aggregations, which accepts a reduce function that can be used to take + * all of the elements of a group and reduce them to a single value. + * + * @tparam T The input and output type for the reduce function. + * @param func The reduce aggregation function. + * @param encoder The encoder for the input and output type of the reduce function. + * @since 2.1.0 + */ +@Experimental +private[sql] class ReduceAggregator[T](func: (T, T) => T, encoder: ExpressionEncoder[T]) + extends Aggregator[T, (Boolean, T), T] { + + /** + * A zero value for this aggregation. It is represented as a Tuple2. The first element of the + * tuple is a false boolean value indicating the buffer is not initialized. The second element + * is initialized as a null value. + * @since 2.1.0 + */ + override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) + + override def bufferEncoder: Encoder[(Boolean, T)] = + ExpressionEncoder.tuple(ExpressionEncoder[Boolean](), encoder) + + override def outputEncoder: Encoder[T] = encoder + + /** + * Combine two values to produce a new value. If the buffer `b` is not initialized, it simply + * takes the value of `a` and set the initialization flag to `true`. + * @since 2.1.0 + */ + override def reduce(b: (Boolean, T), a: T): (Boolean, T) = { + if (b._1) { + (true, func(b._2, a)) + } else { + (true, a) + } + } + + /** + * Merge two intermediate values. As it is possibly that the buffer is just the `zero` value + * coming from empty partition, it checks if the buffers are initialized, and only performs + * merging when they are initialized both. + * @since 2.1.0 + */ + override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = { + if (!b1._1) { + b2 + } else if (!b2._1) { + b1 + } else { + (true, func(b1._2, b2._2)) + } + } + + /** + * Transform the output of the reduction. Simply output the value in the buffer. + * @since 2.1.0 + */ + override def finish(reduction: (Boolean, T)): T = { + reduction._2 + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index ddc4dcd2395b..535a6c7d21ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.language.postfixOps import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -314,4 +314,42 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] assert(ds3.select(NameAgg.toColumn).schema.head.nullable === true) } + + test("ReduceAggregator: zero value") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) + assert(aggregator.zero == (false, null)) + } + + test("ReduceAggregator: reduce, merge and finish") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) + + val firstReduce = aggregator.reduce(aggregator.zero, 1) + assert(firstReduce == (true, 1)) + + val secondReduce = aggregator.reduce(firstReduce, 2) + assert(secondReduce == (true, 3)) + + val thirdReduce = aggregator.reduce(secondReduce, 3) + assert(thirdReduce == (true, 6)) + + val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) + assert(mergeWithZero1 == (true, 1)) + + val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) + assert(mergeWithZero2 == (true, 3)) + + val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) + assert(mergeTwoReduced == (true, 4)) + + assert(aggregator.finish(firstReduce)== 1) + assert(aggregator.finish(secondReduce) == 3) + assert(aggregator.finish(thirdReduce) == 6) + assert(aggregator.finish(mergeWithZero1) == 1) + assert(aggregator.finish(mergeWithZero2) == 3) + assert(aggregator.finish(mergeTwoReduced) == 4) + } }