Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator

/**
* :: Experimental ::
Expand Down Expand Up @@ -177,10 +178,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
  // Express the per-key reduction as a typed aggregation: ReduceAggregator applies `f` both
  // when folding an input value into its buffer and when merging two buffers.
  val encoder = encoderFor[V]
  val aggregator: TypedColumn[V, V] = new ReduceAggregator(f, encoder).toColumn

  // The former `flatMapGroups`-based path (its `func` closure, the implicit tuple encoder and
  // the discarded `flatMapGroups(func)` call) was dead code and has been dropped; the result
  // has always been the one produced by `agg`.
  agg(aggregator)
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.expressions

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

/**
 * :: Experimental ::
 * An [[Aggregator]] that collapses all elements of a group into a single value by repeatedly
 * applying a binary reduce function.
 *
 * The intermediate buffer is a pair `(initialized, value)`. The boolean records whether any
 * input has been folded in yet, because the starting buffer only holds a null placeholder and
 * not a real element.
 *
 * @tparam T The input and output type for the reduce function.
 * @param func The reduce aggregation function.
 * @param encoder The encoder for the input and output type of the reduce function.
 * @since 2.1.0
 */
@Experimental
private[sql] class ReduceAggregator[T](func: (T, T) => T, encoder: ExpressionEncoder[T])
  extends Aggregator[T, (Boolean, T), T] {

  /**
   * The starting buffer: the flag is `false` (nothing has been reduced yet) and the value slot
   * holds a null placeholder.
   * @since 2.1.0
   */
  override def zero: (Boolean, T) = (false, null.asInstanceOf[T])

  /** Encoder for the `(initialized, value)` buffer, built from a Boolean encoder and `encoder`. */
  override def bufferEncoder: Encoder[(Boolean, T)] =
    ExpressionEncoder.tuple(ExpressionEncoder[Boolean](), encoder)

  /** The output type equals the input type, so the supplied encoder is reused as is. */
  override def outputEncoder: Encoder[T] = encoder

  /**
   * Fold the input `a` into the buffer `b`. When `b` has not been initialized yet, `a` itself
   * becomes the buffered value; otherwise `func` combines the buffered value with `a`. The
   * returned buffer is always flagged as initialized.
   * @since 2.1.0
   */
  override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
    val folded = if (b._1) func(b._2, a) else a
    (true, folded)
  }

  /**
   * Merge two intermediate buffers. Either side may still be the uninitialized `zero` buffer
   * (for instance, one coming from an empty partition), so `func` is applied only when both
   * sides are initialized; otherwise the other buffer is returned unchanged.
   * @since 2.1.0
   */
  override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
    if (b1._1 && b2._1) {
      (true, func(b1._2, b2._2))
    } else if (b1._1) {
      // Only the left buffer carries a value.
      b1
    } else {
      // The left buffer is uninitialized; the right one is the answer whatever its state.
      b2
    }
  }

  /**
   * Produce the final result, which is simply the value held in the buffer.
   * @since 2.1.0
   */
  override def finish(reduction: (Boolean, T)): T = reduction._2
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.language.postfixOps

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator}
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
Expand Down Expand Up @@ -314,4 +314,42 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
assert(ds3.select(NameAgg.toColumn).schema.head.nullable === true)
}

test("ReduceAggregator: zero value") {
  // The reduce function is never invoked here; `zero` must be the uninitialized buffer.
  val sum: (Int, Int) => Int = _ + _
  val reducer = new ReduceAggregator[Int](sum, ExpressionEncoder[Int]())
  assert(reducer.zero == (false, null))
}

test("ReduceAggregator: reduce, merge and finish") {
  val sum: (Int, Int) => Int = _ + _
  val reducer = new ReduceAggregator[Int](sum, ExpressionEncoder[Int]())
  val empty = reducer.zero

  // reduce: the first input initializes the buffer, later inputs are summed into it.
  val afterOne = reducer.reduce(empty, 1)
  assert(afterOne == (true, 1))

  val afterTwo = reducer.reduce(afterOne, 2)
  assert(afterTwo == (true, 3))

  val afterThree = reducer.reduce(afterTwo, 3)
  assert(afterThree == (true, 6))

  // merge: an uninitialized buffer on either side leaves the other side untouched.
  val zeroWithOne = reducer.merge(empty, afterOne)
  assert(zeroWithOne == (true, 1))

  val twoWithZero = reducer.merge(afterTwo, empty)
  assert(twoWithZero == (true, 3))

  // merge: two initialized buffers are combined with the reduce function.
  val oneWithTwo = reducer.merge(afterOne, afterTwo)
  assert(oneWithTwo == (true, 4))

  // finish: unwraps the value stored in the buffer.
  assert(reducer.finish(afterOne) == 1)
  assert(reducer.finish(afterTwo) == 3)
  assert(reducer.finish(afterThree) == 6)
  assert(reducer.finish(zeroWithOne) == 1)
  assert(reducer.finish(twoWithZero) == 3)
  assert(reducer.finish(oneWithTwo) == 4)
}
}