diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 5e55620147df..28fc2cc4861b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -20,7 +20,6 @@ package org.apache.spark.graphx
 import scala.reflect.ClassTag
 import org.apache.spark.Logging
 
-
 /**
  * Implements a Pregel-like bulk-synchronous message-passing API.
  *
@@ -106,14 +105,19 @@ object Pregel extends Logging {
    * A. ''This function must be commutative and associative and
    * ideally the size of A should not increase.''
    *
+   * @param tripletFields which fields should be included in the edge
+   * triplet passed to `sendMsg`. If not all fields are needed,
+   * specifying this can improve performance.
+   *
    * @return the resulting graph at the end of the computation
    *
    */
   def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
      (graph: Graph[VD, ED],
       initialMsg: A,
-      maxIterations: Int = Int.MaxValue,
-      activeDirection: EdgeDirection = EdgeDirection.Either)
+      maxIterations: Int,
+      activeDirection: EdgeDirection,
+      tripletFields: TripletFields)
      (vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       mergeMsg: (A, A) => A)
@@ -121,7 +125,7 @@ object Pregel extends Logging {
   {
     var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
     // compute the messages
-    var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
+    var messages = mapReduceTriplets(g, sendMsg, mergeMsg, tripletFields)
     var activeMessages = messages.count()
     // Loop
     var prevG: Graph[VD, ED] = null
@@ -138,7 +142,8 @@ object Pregel extends Logging {
       // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
       // get to send messages. We must cache messages so it can be materialized on the next line,
       // allowing us to uncache the previous iteration.
-      messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
+      messages = mapReduceTriplets(
+        g, sendMsg, mergeMsg, tripletFields, Some((newVerts, activeDirection))).cache()
       // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
       // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
       // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
@@ -158,4 +163,58 @@ object Pregel extends Logging {
     g
   } // end of apply
 
+  /**
+   * Overload matching the old Pregel API (<= 1.2.0), retained for binary
+   * compatibility. It forwards to the `tripletFields`-aware overload,
+   * passing `TripletFields.All`.
+   */
+  def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
+     (graph: Graph[VD, ED],
+      initialMsg: A,
+      maxIterations: Int = Int.MaxValue,
+      activeDirection: EdgeDirection = EdgeDirection.Either)
+     (vprog: (VertexId, VD, A) => VD,
+      sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
+      mergeMsg: (A, A) => A)
+    : Graph[VD, ED] =
+  {
+    Pregel(graph, initialMsg, maxIterations, activeDirection, TripletFields.All)(
+      vprog, sendMsg, mergeMsg)
+  }
+
+  /**
+   * Adapts a triplet-based `mapFunc` to the `EdgeContext` API and runs it through
+   * `aggregateMessagesWithActiveSet`.
+   *
+   * @param g graph whose `aggregateMessagesWithActiveSet` is invoked
+   * @param mapFunc produces `(vertexId, message)` pairs for an edge triplet; each id must be
+   *                the source or destination vertex of that triplet
+   * @param reduceFunc forwarded unchanged to `aggregateMessagesWithActiveSet`
+   * @param tripletFields forwarded unchanged to `aggregateMessagesWithActiveSet`
+   * @param activeSetOpt forwarded unchanged to `aggregateMessagesWithActiveSet`
+   * @return the `VertexRDD` returned by `aggregateMessagesWithActiveSet`
+   */
+  private def mapReduceTriplets[VD: ClassTag, ED: ClassTag, A: ClassTag](
+      g: Graph[VD, ED],
+      mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
+      reduceFunc: (A, A) => A,
+      tripletFields: TripletFields,
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
+    : VertexRDD[A] = {
+
+    // Route each message to whichever endpoint of the current edge it is addressed to.
+    def sendMsg(ctx: EdgeContext[VD, ED, A]): Unit = {
+      mapFunc(ctx.toEdgeTriplet).foreach { case (id, msg) =>
+        if (id == ctx.srcId) {
+          ctx.sendToSrc(msg)
+        } else {
+          assert(id == ctx.dstId,
+            s"message target $id is neither endpoint of edge ${ctx.srcId} -> ${ctx.dstId}")
+          ctx.sendToDst(msg)
+        }
+      }
+    }
+    g.aggregateMessagesWithActiveSet(
+      sendMsg, reduceFunc, tripletFields, activeSetOpt)
+  }
 } // end of class Pregel