Skip to content

Commit 1e80aca

Browse files
committed
Add aggregateMessages, which supersedes mapReduceTriplets
aggregateMessages enables neighborhood computation similarly to mapReduceTriplets, but it introduces two API improvements: 1. Messages are sent using an imperative interface based on EdgeContext rather than by returning an iterator of messages. This is more efficient, providing a 20.2% speedup on PageRank over #3054 (uk-2007-05 graph, 10 iterations, 16 r3.2xlarge machines, sped up from 403 s to 322 s). 2. Rather than attempting bytecode inspection, the required triplet fields must be explicitly specified by the user by passing a TripletFields object. This fixes SPARK-3936. Subsumes #2815.
1 parent 194a2df commit 1e80aca

File tree

10 files changed

+374
-160
lines changed

10 files changed

+374
-160
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.graphx
19+
20+
/**
21+
* Represents an edge along with its neighboring vertices and allows sending messages along the
22+
* edge. Used in [[Graph#aggregateMessages]].
23+
*/
24+
trait EdgeContext[VD, ED, A] {
25+
/** The vertex id of the edge's source vertex. */
26+
def srcId: VertexId
27+
/** The vertex id of the edge's destination vertex. */
28+
def dstId: VertexId
29+
/** The vertex attribute of the edge's source vertex. */
30+
def srcAttr: VD
31+
/** The vertex attribute of the edge's destination vertex. */
32+
def dstAttr: VD
33+
/** The attribute associated with the edge. */
34+
def attr: ED
35+
36+
/** Sends a message to the source vertex. */
37+
def sendToSrc(msg: A): Unit
38+
/** Sends a message to the destination vertex. */
39+
def sendToDst(msg: A): Unit
40+
41+
/** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */
42+
def toEdgeTriplet: EdgeTriplet[VD, ED] = {
43+
val et = new EdgeTriplet[VD, ED]
44+
et.srcId = srcId
45+
et.srcAttr = srcAttr
46+
et.dstId = dstId
47+
et.dstAttr = dstAttr
48+
et.attr = attr
49+
et
50+
}
51+
}

graphx/src/main/scala/org/apache/spark/graphx/Graph.scala

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
195195
* the underlying index structures can be reused.
196196
*
197197
* @param map the function from an edge object to a new edge value.
198+
* @param tripletFields which fields should be included in the edge triplet passed to the map
199+
* function. If not all fields are needed, specifying this can improve performance.
198200
*
199201
* @tparam ED2 the new edge data type
200202
*
@@ -207,8 +209,10 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
207209
* }}}
208210
*
209211
*/
210-
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
211-
mapTriplets((pid, iter) => iter.map(map))
212+
def mapTriplets[ED2: ClassTag](
213+
map: EdgeTriplet[VD, ED] => ED2,
214+
tripletFields: TripletFields = TripletFields.All): Graph[VD, ED2] = {
215+
mapTriplets((pid, iter) => iter.map(map), tripletFields)
212216
}
213217

214218
/**
@@ -223,12 +227,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
223227
* the underlying index structures can be reused.
224228
*
225229
* @param map the iterator transform
230+
* @param tripletFields which fields should be included in the edge triplet passed to the map
231+
* function. If not all fields are needed, specifying this can improve performance.
226232
*
227233
* @tparam ED2 the new edge data type
228234
*
229235
*/
230-
def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
231-
: Graph[VD, ED2]
236+
def mapTriplets[ED2: ClassTag](
237+
map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
238+
tripletFields: TripletFields): Graph[VD, ED2]
232239

233240
/**
234241
* Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
@@ -287,6 +294,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
287294
* "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of
288295
* the map phase destined to each vertex.
289296
*
297+
* This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead.
298+
*
290299
* @tparam A the type of "message" to be sent to each vertex
291300
*
292301
* @param mapFunc the user defined map function which returns 0 or
@@ -319,15 +328,62 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
319328
* predicate or implement PageRank.
320329
*
321330
*/
331+
@deprecated("use aggregateMessages", "1.2.0")
322332
def mapReduceTriplets[A: ClassTag](
323333
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
324334
reduceFunc: (A, A) => A,
325335
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
326336
: VertexRDD[A]
327337

328338
/**
329-
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
330-
* input table should contain at most one entry for each vertex. If no entry in `other` is
339+
* Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
340+
* `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
341+
* sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
342+
* destined to the same vertex.
343+
*
344+
* @tparam A the type of message to be sent to each vertex
345+
*
346+
* @param sendMsg runs on each edge, sending messages to neighboring vertices using the
347+
* [[EdgeContext]].
348+
* @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
349+
* combiner should be commutative and associative.
350+
* @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
351+
* `sendMsg` function. If not all fields are needed, specifying this can improve performance.
352+
* @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
353+
* desired. This is done by specifying a set of "active" vertices and an edge direction. The
354+
* `sendMsg` function will then run on only edges connected to active vertices by edges in the
355+
* specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
356+
* destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
357+
* originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
358+
* run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
359+
* will be run on edges with *both* vertices in the active set. The active set must have the
360+
* same index as the graph's vertices.
361+
*
362+
* @example We can use this function to compute the in-degree of each
363+
* vertex
364+
* {{{
365+
* val rawGraph: Graph[_, _] = Graph.textFile("twittergraph")
366+
* val inDeg: RDD[(VertexId, Int)] =
367+
* aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
368+
* }}}
369+
*
370+
* @note By expressing computation at the edge level we achieve
371+
* maximum parallelism. This is one of the core functions in the
372+
* Graph API in that enables neighborhood level computation. For
373+
* example this function can be used to count neighbors satisfying a
374+
* predicate or implement PageRank.
375+
*
376+
*/
377+
def aggregateMessages[A: ClassTag](
378+
sendMsg: EdgeContext[VD, ED, A] => Unit,
379+
mergeMsg: (A, A) => A,
380+
tripletFields: TripletFields = TripletFields.All,
381+
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
382+
: VertexRDD[A]
383+
384+
/**
385+
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
386+
* The input table should contain at most one entry for each vertex. If no entry in `other` is
331387
* provided for a particular vertex in the graph, the map function receives `None`.
332388
*
333389
* @tparam U the type of entry in the table of updates

graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
6969
*/
7070
private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
7171
if (edgeDirection == EdgeDirection.In) {
72-
graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
72+
graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None)
7373
} else if (edgeDirection == EdgeDirection.Out) {
74-
graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
74+
graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None)
7575
} else { // EdgeDirection.Either
76-
graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
76+
graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _,
77+
TripletFields.None)
7778
}
7879
}
7980

@@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
8889
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
8990
val nbrs =
9091
if (edgeDirection == EdgeDirection.Either) {
91-
graph.mapReduceTriplets[Array[VertexId]](
92-
mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
93-
reduceFunc = _ ++ _
94-
)
92+
graph.aggregateMessages[Array[VertexId]](
93+
ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
94+
_ ++ _, TripletFields.None)
9595
} else if (edgeDirection == EdgeDirection.Out) {
96-
graph.mapReduceTriplets[Array[VertexId]](
97-
mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
98-
reduceFunc = _ ++ _)
96+
graph.aggregateMessages[Array[VertexId]](
97+
ctx => ctx.sendToSrc(Array(ctx.dstId)),
98+
_ ++ _, TripletFields.None)
9999
} else if (edgeDirection == EdgeDirection.In) {
100-
graph.mapReduceTriplets[Array[VertexId]](
101-
mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
102-
reduceFunc = _ ++ _)
100+
graph.aggregateMessages[Array[VertexId]](
101+
ctx => ctx.sendToDst(Array(ctx.srcId)),
102+
_ ++ _, TripletFields.None)
103103
} else {
104104
throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
105105
"direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
@@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
122122
* @return the vertex set of neighboring vertex attributes for each vertex
123123
*/
124124
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
125-
val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
126-
edge => {
127-
val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
128-
val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
129-
edgeDirection match {
130-
case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
131-
case EdgeDirection.In => Iterator(msgToDst)
132-
case EdgeDirection.Out => Iterator(msgToSrc)
133-
case EdgeDirection.Both =>
134-
throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
135-
"EdgeDirection.Either instead.")
136-
}
137-
},
138-
(a, b) => a ++ b)
139-
140-
graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
125+
val nbrs = edgeDirection match {
126+
case EdgeDirection.Either =>
127+
graph.aggregateMessages[Array[(VertexId,VD)]](
128+
ctx => {
129+
ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
130+
ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
131+
},
132+
(a, b) => a ++ b, TripletFields.SrcDstOnly)
133+
case EdgeDirection.In =>
134+
graph.aggregateMessages[Array[(VertexId,VD)]](
135+
ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
136+
(a, b) => a ++ b, TripletFields.SrcOnly)
137+
case EdgeDirection.Out =>
138+
graph.aggregateMessages[Array[(VertexId,VD)]](
139+
ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
140+
(a, b) => a ++ b, TripletFields.DstOnly)
141+
case EdgeDirection.Both =>
142+
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
143+
"EdgeDirection.Either instead.")
144+
}
145+
graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
141146
nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
142147
}
143148
} // end of collectNeighbor
@@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
160165
def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
161166
edgeDirection match {
162167
case EdgeDirection.Either =>
163-
graph.mapReduceTriplets[Array[Edge[ED]]](
164-
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
165-
(edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
166-
(a, b) => a ++ b)
168+
graph.aggregateMessages[Array[Edge[ED]]](
169+
ctx => {
170+
ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
171+
ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
172+
},
173+
(a, b) => a ++ b, TripletFields.EdgeOnly)
167174
case EdgeDirection.In =>
168-
graph.mapReduceTriplets[Array[Edge[ED]]](
169-
edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
170-
(a, b) => a ++ b)
175+
graph.aggregateMessages[Array[Edge[ED]]](
176+
ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
177+
(a, b) => a ++ b, TripletFields.EdgeOnly)
171178
case EdgeDirection.Out =>
172-
graph.mapReduceTriplets[Array[Edge[ED]]](
173-
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
174-
(a, b) => a ++ b)
179+
graph.aggregateMessages[Array[Edge[ED]]](
180+
ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
181+
(a, b) => a ++ b, TripletFields.EdgeOnly)
175182
case EdgeDirection.Both =>
176183
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
177184
"EdgeDirection.Either instead.")
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.graphx
19+
20+
/**
21+
* Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the
22+
* system to populate only those fields for efficiency.
23+
*/
24+
class TripletFields private (
25+
val useSrc: Boolean,
26+
val useDst: Boolean,
27+
val useEdge: Boolean)
28+
extends Serializable {
29+
private def this() = this(true, true, true)
30+
}
31+
32+
/**
33+
* Exposes all possible [[TripletFields]] objects.
34+
*/
35+
object TripletFields {
36+
final val None = new TripletFields(useSrc = false, useDst = false, useEdge = false)
37+
final val EdgeOnly = new TripletFields(useSrc = false, useDst = false, useEdge = true)
38+
final val SrcOnly = new TripletFields(useSrc = true, useDst = false, useEdge = false)
39+
final val DstOnly = new TripletFields(useSrc = false, useDst = true, useEdge = false)
40+
final val SrcDstOnly = new TripletFields(useSrc = true, useDst = true, useEdge = false)
41+
final val SrcAndEdge = new TripletFields(useSrc = true, useDst = false, useEdge = true)
42+
final val Src = SrcAndEdge
43+
final val DstAndEdge = new TripletFields(useSrc = false, useDst = true, useEdge = true)
44+
final val Dst = DstAndEdge
45+
final val All = new TripletFields(useSrc = true, useDst = true, useEdge = true)
46+
47+
/** Returns the appropriate [[TripletFields]] object. */
48+
private[graphx] def apply(useSrc: Boolean, useDst: Boolean, useEdge: Boolean) =
49+
(useSrc, useDst, useEdge) match {
50+
case (false, false, false) => TripletFields.None
51+
case (false, false, true) => EdgeOnly
52+
case (true, false, false) => SrcOnly
53+
case (false, true, false) => DstOnly
54+
case (true, true, false) => SrcDstOnly
55+
case (true, false, true) => SrcAndEdge
56+
case (false, true, true) => DstAndEdge
57+
case (true, true, true) => All
58+
}
59+
}

0 commit comments

Comments
 (0)