Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.graphx
import scala.reflect.ClassTag
import org.apache.spark.Logging


/**
* Implements a Pregel-like bulk-synchronous message-passing API.
*
Expand Down Expand Up @@ -106,22 +105,27 @@ object Pregel extends Logging {
* A. ''This function must be commutative and associative and
* ideally the size of A should not increase.''
*
* @param tripletFields which fields should be included in the edge
* triplet passed to `sendMsg`. If `sendMsg` does not need all of the
* fields, specifying only the ones it reads can improve performance.
*
* @return the resulting graph at the end of the computation
*
*/
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
(graph: Graph[VD, ED],
initialMsg: A,
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either)
maxIterations: Int,
activeDirection: EdgeDirection,
tripletFields: TripletFields)
(vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A)
: Graph[VD, ED] =
{
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
// compute the messages
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
var messages = mapReduceTriplets(g, sendMsg, mergeMsg, tripletFields)
var activeMessages = messages.count()
// Loop
var prevG: Graph[VD, ED] = null
Expand All @@ -138,7 +142,8 @@ object Pregel extends Logging {
// Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
// get to send messages. We must cache messages so it can be materialized on the next line,
// allowing us to uncache the previous iteration.
messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
messages = mapReduceTriplets(
g, sendMsg, mergeMsg, tripletFields, Some((newVerts, activeDirection))).cache()
// The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
// hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
// vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
Expand All @@ -158,4 +163,45 @@ object Pregel extends Logging {
g
} // end of apply

/**
 * Legacy Pregel entry point (the API as of <=1.2.0), retained for
 * binary compatibility.
 *
 * It forwards every argument unchanged to the overload that also takes
 * a `TripletFields` argument, passing `TripletFields.All` for it.
 */
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
  (graph: Graph[VD, ED],
   initialMsg: A,
   maxIterations: Int = Int.MaxValue,
   activeDirection: EdgeDirection = EdgeDirection.Either)
  (vprog: (VertexId, VD, A) => VD,
   sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
   mergeMsg: (A, A) => A)
  : Graph[VD, ED] =
{
  // The legacy signature has no way to narrow the triplet, so request all fields.
  val everyField = TripletFields.All
  Pregel.apply(graph, initialMsg, maxIterations, activeDirection, everyField)(
    vprog, sendMsg, mergeMsg)
}

/**
 * Runs the triplet-based `mapFunc` through `Graph.aggregateMessagesWithActiveSet`.
 *
 * Each `(VertexId, A)` pair produced by `mapFunc` for an edge is routed to that
 * edge's source vertex when the id equals `ctx.srcId`, and otherwise to its
 * destination vertex.
 *
 * @param g the graph whose edges are visited
 * @param mapFunc produces the messages for one edge triplet; every target id it
 *                returns must be either the source or the destination of that edge
 * @param reduceFunc combines two messages addressed to the same vertex
 * @param tripletFields forwarded as-is to `aggregateMessagesWithActiveSet`
 * @param activeSetOpt forwarded as-is to `aggregateMessagesWithActiveSet`
 *                     (presumably restricts which edges are visited; confirm
 *                     against that method's documentation)
 * @return the `VertexRDD` returned by `aggregateMessagesWithActiveSet`
 */
private def mapReduceTriplets[VD: ClassTag, ED: ClassTag, A: ClassTag](
    g: Graph[VD, ED],
    mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    reduceFunc: (A, A) => A,
    tripletFields: TripletFields,
    activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
  : VertexRDD[A] = {

  // Adapter from the (VertexId, message) iterator style to EdgeContext sends.
  def sendMsg(ctx: EdgeContext[VD, ED, A]): Unit = {
    mapFunc(ctx.toEdgeTriplet).foreach { kv =>
      val id = kv._1
      val msg = kv._2
      if (id == ctx.srcId) {
        ctx.sendToSrc(msg)
      } else {
        // A message can only be delivered to one of the two endpoints of the
        // edge being processed; anything else is a bug in the caller's sendMsg.
        assert(id == ctx.dstId,
          s"sendMsg returned a message for vertex $id, which is neither the source " +
            s"(${ctx.srcId}) nor the destination (${ctx.dstId}) of the current edge")
        ctx.sendToDst(msg)
      }
    }
  }

  g.aggregateMessagesWithActiveSet(
    sendMsg, reduceFunc, tripletFields, activeSetOpt)
}
} // end of class Pregel