diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala
new file mode 100644
index 000000000000..58544b342832
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{NarrowDependency, Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+
+/**
+ * Represents a binary tree dependency, where partition `i` depends on partitions `2 * i` and,
+ * if it exists, `2 * i + 1` of the parent RDD.
+ * @param rdd parent RDD
+ * @tparam T value type
+ */
+private class BinaryTreeDependency[T](@transient rdd: RDD[T]) extends NarrowDependency(rdd) {
+
+  val n = rdd.partitions.size
+
+  override def getParents(partitionId: Int): Seq[Int] = {
+    val i1 = 2 * partitionId
+    val i2 = i1 + 1
+    if (i2 < n) {
+      Seq(i1, i2)
+    } else {
+      Seq(i1)
+    }
+  }
+}
+
+/** A tree-node partition that merges one or two adjacent partitions of the parent RDD. */
+private class BinaryTreeNodePartition(
+    override val index: Int,
+    val left: Partition,
+    val right: Option[Partition]) extends Partition
+
+private object BinaryTreeNodePartition {
+  def apply(rdd: RDD[_], i: Int): Partition = {
+    val n = rdd.partitions.size
+    val i1 = 2 * i
+    val i2 = i1 + 1
+    if (i2 < n) {
+      new BinaryTreeNodePartition(i, rdd.partitions(i1), Some(rdd.partitions(i2)))
+    } else {
+      new BinaryTreeNodePartition(i, rdd.partitions(i1), None)
+    }
+  }
+}
+
+/**
+ * An RDD that halves the number of partitions of its parent RDD by reducing each pair of
+ * adjacent partitions into a single record.
+ */
+private[mllib] class BinaryTreeReducedRDD[T: ClassTag](rdd: RDD[T], f: (T, T) => T)
+  extends RDD[T](rdd.context, List(new BinaryTreeDependency(rdd))) {
+
+  override protected def getPartitions: Array[Partition] = {
+    Array.tabulate((rdd.partitions.size + 1) / 2)(i => BinaryTreeNodePartition(rdd, i))
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+    val p = split.asInstanceOf[BinaryTreeNodePartition]
+    // Use `iterator` rather than `compute` so cached or checkpointed parent partitions are reused.
+    val iterLeft = rdd.iterator(p.left, context)
+    val iterRight = if (p.right.isDefined) rdd.iterator(p.right.get, context) else Iterator.empty
+    val iter = iterLeft ++ iterRight
+    if (iter.isEmpty) {
+      Iterator.empty
+    } else {
+      Iterator(iter.reduce(f))
+    }
+  }
+
+  override protected def getPreferredLocations(split: Partition): Seq[String] = {
+    val p = split.asInstanceOf[BinaryTreeNodePartition]
+    rdd.preferredLocations(p.left)
+  }
+}
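To make the partition wiring above concrete, here is a small self-contained sketch (plain Scala, no cluster required, not part of the patch) that mirrors `BinaryTreeDependency.getParents` and prints which parent partitions each child partition of a `BinaryTreeReducedRDD` would read, for a hypothetical five-partition parent:

```scala
object BinaryTreeDependencyDemo {
  // Mirrors BinaryTreeDependency.getParents for a parent RDD with n partitions.
  def parents(partitionId: Int, n: Int): Seq[Int] = {
    val i1 = 2 * partitionId
    val i2 = i1 + 1
    if (i2 < n) Seq(i1, i2) else Seq(i1)
  }

  def main(args: Array[String]): Unit = {
    val n = 5 // hypothetical parent partition count
    for (i <- 0 until (n + 1) / 2) {
      println(s"child partition $i <- parent partition(s) ${parents(i, n).mkString(", ")}")
    }
    // Prints:
    // child partition 0 <- parent partition(s) 0, 1
    // child partition 1 <- parent partition(s) 2, 3
    // child partition 2 <- parent partition(s) 4
  }
}
```

Each pass therefore shrinks an `n`-partition parent to `(n + 1) / 2` partitions, which is why the reduction finishes in about log2(n) passes.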
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala
new file mode 100644
index 000000000000..4f4cd303efc3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{NarrowDependency, Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+
+/**
+ * Represents a butterfly dependency, where partition `i` depends on partitions `i` and
+ * `(i + offset) % n` of the parent RDD, where `n` is the number of parent partitions.
+ */
+private class ButterflyDependency[T](rdd: RDD[T], offset: Int) extends NarrowDependency(rdd) {
+
+  val n = rdd.partitions.size
+
+  override def getParents(partitionId: Int): Seq[Int] = {
+    Seq(partitionId, (partitionId + offset) % n)
+  }
+}
+
+/** A partition in a butterfly-reduced RDD. */
+private case class ButterflyReducedRDDPartition(
+    override val index: Int,
+    source: Partition,
+    target: Partition) extends Partition
+
+/**
+ * Butterfly-reduced RDD, in which partition `i` holds the reduced value of partitions `i` and
+ * `(i + offset) % n` of the parent RDD.
+ */
+private[mllib] class ButterflyReducedRDD[T: ClassTag](
+    @transient rdd: RDD[T],
+    reducer: (T, T) => T,
+    @transient offset: Int)
+  extends RDD[T](rdd.context, List(new ButterflyDependency(rdd, offset))) {
+
+  /** Computes the target partition. */
+  private def targetPartition(i: Int): Partition = {
+    val j = (i + offset) % rdd.partitions.size
+    rdd.partitions(j)
+  }
+
+  override def getPartitions: Array[Partition] = {
+    rdd.partitions.zipWithIndex.map { case (part, i) =>
+      ButterflyReducedRDDPartition(i, part, targetPartition(i))
+    }
+  }
+
+  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+    val pair = s.asInstanceOf[ButterflyReducedRDDPartition]
+    val iter = firstParent[T].iterator(pair.source, context) ++
+      firstParent[T].iterator(pair.target, context)
+    // The padded partitions created by `RDDFunctions.allReduce` are empty, so guard the reduce.
+    if (iter.isEmpty) {
+      Iterator.empty
+    } else {
+      Iterator(iter.reduce(reducer))
+    }
+  }
+
+  override def getPreferredLocations(s: Partition): Seq[String] = {
+    rdd.preferredLocations(s.asInstanceOf[ButterflyReducedRDDPartition].source)
+  }
+}
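As a sanity check on the butterfly wiring, the following standalone sketch (illustration only, not part of the patch) simulates the rounds that `allReduce` in the next file chains together: starting from one value per partition, each round combines slot `i` with slot `(i + offset) % n`, halving `offset` until it reaches zero. After log2(n) rounds every slot holds the full reduction:

```scala
object ButterflyDemo {
  def main(args: Array[String]): Unit = {
    // One value per "partition"; 8 is a power of two, as allReduce guarantees via padding.
    var values = Array(1, 2, 3, 4, 5, 6, 7, 8)
    val n = values.length
    var offset = n / 2
    while (offset > 0) {
      // One butterfly round: slot i merges with slot (i + offset) % n.
      values = Array.tabulate(n)(i => values(i) + values((i + offset) % n))
      offset /= 2
    }
    println(values.mkString(", ")) // 36, 36, 36, 36, 36, 36, 36, 36
  }
}
```

Note that every slot stays busy in every round, which is the point of the butterfly: all partitions end up holding the result, with no final broadcast step.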
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index 365b5e75d7f7..5de5da7c4b66 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
 
 /**
  * Machine learning specific RDD functions.
@@ -44,6 +44,76 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
       new SlidingRDD[T](self, windowSize)
     }
   }
+
+  /**
+   * Computes the all-reduced RDD of the parent RDD, which has the same number of partitions
+   * and the same locality information as its parent RDD. Each output partition contains exactly
+   * one record, the value that calling `RDD#reduce` on the parent RDD would return.
+   *
+   * @param f reducer
+   * @return all-reduced RDD
+   */
+  def allReduce(f: (T, T) => T): RDD[T] = {
+    val numPartitions = self.partitions.size
+    require(numPartitions > 0, "Parent RDD does not have any partitions.")
+    // Round the number of partitions up to a power of two, e.g., 5 -> 8, 4 -> 4.
+    val nextPowerOfTwo = {
+      var p = 1
+      while (p < numPartitions) {
+        p <<= 1
+      }
+      p
+    }
+    // Reduce each partition to at most one record; empty partitions contribute nothing.
+    var butterfly = self.mapPartitions(iter =>
+      if (iter.isEmpty) Iterator.empty else Iterator(iter.reduce(f)),
+      preservesPartitioning = true
+    ).cache()
+
+    // Pad with empty partitions so that the butterfly runs on a power-of-two partition count.
+    if (nextPowerOfTwo > numPartitions) {
+      val padding = self.context.parallelize(Seq.empty[T], nextPowerOfTwo - numPartitions)
+      butterfly = butterfly.union(padding)
+    }
+
+    // Run the butterfly rounds, halving the offset each time.
+    var offset = nextPowerOfTwo >> 1
+    while (offset > 0) {
+      butterfly = new ButterflyReducedRDD[T](butterfly, f, offset).cache()
+      offset >>= 1
+    }
+
+    // Drop the padded partitions, restoring the original partition count.
+    if (nextPowerOfTwo > numPartitions) {
+      PartitionPruningRDD.create(butterfly, i => i < numPartitions)
+    } else {
+      butterfly
+    }
+  }
+
+  /**
+   * Reduces the elements of this RDD using a binary tree: each partition is first reduced to a
+   * single record, then pairs of partitions are merged until only a few partitions remain, and
+   * the remaining records are combined via `RDD#reduce`.
+   *
+   * @param f reducer
+   * @return reduced value
+   */
+  def binaryTreeReduce(f: (T, T) => T): T = {
+    var reduced = self.mapPartitions(iter =>
+      if (iter.isEmpty) {
+        Iterator.empty
+      } else {
+        Iterator(iter.reduce(f))
+      },
+      preservesPartitioning = true
+    )
+    // Halve the number of partitions per pass until few enough remain to reduce directly.
+    while (reduced.partitions.size > 3) {
+      reduced = new BinaryTreeReducedRDD(reduced, f)
+    }
+    reduced.reduce(f)
+  }
 }
 
 private[mllib]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 3f3b10dfff35..3900bc081277 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -46,4 +46,24 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toList
     assert(sliding.collect().toList === expected)
   }
+
+  test("allReduce") {
+    for (numPartitions <- 1 to 10) {
+      val rdd = sc.parallelize(0 until 1000, numPartitions)
+      val sum = rdd.reduce(_ + _)
+      val allReduced = rdd.allReduce(_ + _)
+      assert(allReduced.partitions.size === numPartitions)
+      assert(allReduced.collect().toSeq === Seq.fill(numPartitions)(sum))
+    }
+  }
+
+  test("binaryTreeReduce") {
+    val data = 0 until 5
+    val expected = data.reduce(_ + _)
+    for (numPartitions <- 1 to 12) {
+      val rdd = sc.parallelize(data, numPartitions)
+      val actual = rdd.binaryTreeReduce(_ + _)
+      assert(actual === expected)
+    }
+  }
 }
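For context, here is how the two new operations might be exercised from a local-mode driver. This is a hypothetical sketch, not part of the patch: the object name `ReduceDemo` is invented, and it constructs `RDDFunctions` directly rather than relying on the implicit conversion in the `RDDFunctions` companion object that the test suite presumably uses:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.rdd.RDDFunctions

object ReduceDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("ReduceDemo"))
    val rdd = sc.parallelize(0 until 1000, 4)
    val fns = new RDDFunctions(rdd)

    // binaryTreeReduce returns a single value, like RDD#reduce, but merges partitions pairwise.
    val sum = fns.binaryTreeReduce(_ + _)

    // allReduce keeps the result distributed: one record per partition, each equal to the sum,
    // so later per-partition tasks can read the reduced value without a driver round trip.
    val sums = fns.allReduce(_ + _).collect()

    println(s"binaryTreeReduce = $sum; allReduce = ${sums.mkString(", ")}")
    sc.stop()
  }
}
```

The two entry points trade off differently: `binaryTreeReduce` brings one value back to the driver in about log2(n) passes, while `allReduce` spends the same number of butterfly rounds to leave a copy of the result in every partition.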