Skip to content

Commit 4a566dc

Browse files
committed
Optimizations for mapReduceTriplets and EdgePartition
1. EdgePartition now stores local vertex ids instead of global ids. This avoids hash lookups when looking up vertex attributes and aggregating messages.
2. Internal iterators in mapReduceTriplets are inlined into a while loop.
1 parent 26d31d1 commit 4a566dc

File tree

7 files changed

+310
-175
lines changed

7 files changed

+310
-175
lines changed

graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala

Lines changed: 187 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.reflect.{classTag, ClassTag}
2121

2222
import org.apache.spark.graphx._
2323
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
24+
import org.apache.spark.util.collection.BitSet
2425

2526
/**
2627
* A collection of edges stored in columnar format, along with any vertex attributes referenced. The
@@ -30,54 +31,76 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
3031
* @tparam ED the edge attribute type
3132
* @tparam VD the vertex attribute type
3233
*
33-
* @param srcIds the source vertex id of each edge
34-
* @param dstIds the destination vertex id of each edge
34+
* @param localSrcIds the local source vertex id of each edge as an index into `local2global` and
35+
* `vertexAttrs`
36+
* @param localDstIds the local destination vertex id of each edge as an index into `local2global`
37+
* and `vertexAttrs`
3538
* @param data the attribute associated with each edge
36-
* @param index a clustered index on source vertex id
37-
* @param vertices a map from referenced vertex ids to their corresponding attributes. Must
38-
* contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for
39-
* those vertex ids. The mask is not used.
39+
* @param index a clustered index on source vertex id as a map from each global source vertex id to
40+
* the offset in the edge arrays where the cluster for that vertex id begins
41+
* @param global2local a map from referenced vertex ids to local ids which index into `vertexAttrs`
42+
* @param local2global an array of global vertex ids where the offsets are local vertex ids
43+
* @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids
4044
* @param activeSet an optional active vertex set for filtering computation on the edges
4145
*/
4246
private[graphx]
4347
class EdgePartition[
4448
@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag](
45-
val srcIds: Array[VertexId] = null,
46-
val dstIds: Array[VertexId] = null,
49+
val localSrcIds: Array[Int] = null,
50+
val localDstIds: Array[Int] = null,
4751
val data: Array[ED] = null,
4852
val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
49-
val vertices: VertexPartition[VD] = null,
53+
val global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
54+
val local2global: Array[VertexId] = null,
55+
val vertexAttrs: Array[VD] = null,
5056
val activeSet: Option[VertexSet] = None
5157
) extends Serializable {
5258

5359
/** Return a new `EdgePartition` with the specified edge data. */
54-
def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = {
55-
new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet)
56-
}
57-
58-
/** Return a new `EdgePartition` with the specified vertex partition. */
59-
def withVertices[VD2: ClassTag](
60-
vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = {
61-
new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet)
60+
def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = {
61+
new EdgePartition(
62+
localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
6263
}
6364

6465
/** Return a new `EdgePartition` with the specified active set, provided as an iterator. */
6566
def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = {
66-
val newActiveSet = new VertexSet
67-
iter.foreach(newActiveSet.add(_))
68-
new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet))
67+
val activeSet = new VertexSet
68+
iter.foreach(activeSet.add(_))
69+
new EdgePartition(
70+
localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs,
71+
Some(activeSet))
6972
}
7073

7174
/** Return a new `EdgePartition` with the specified active set. */
72-
def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = {
73-
new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_)
75+
def withActiveSet(activeSet: Option[VertexSet]): EdgePartition[ED, VD] = {
76+
new EdgePartition(
77+
localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
7478
}
7579

7680
/** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */
7781
def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = {
78-
this.withVertices(vertices.innerJoinKeepLeft(iter))
82+
val newVertexAttrs = new Array[VD](vertexAttrs.length)
83+
System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length)
84+
iter.foreach { kv =>
85+
newVertexAttrs(global2local(kv._1)) = kv._2
86+
}
87+
new EdgePartition(
88+
localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
89+
activeSet)
90+
}
91+
92+
/** Return a new `EdgePartition` without any locally cached vertex attributes. */
93+
def clearVertices[VD2: ClassTag](): EdgePartition[ED, VD2] = {
94+
val newVertexAttrs = new Array[VD2](vertexAttrs.length)
95+
new EdgePartition(
96+
localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
97+
activeSet)
7998
}
8099

100+
def srcIds(i: Int): VertexId = local2global(localSrcIds(i))
101+
102+
def dstIds(i: Int): VertexId = local2global(localDstIds(i))
103+
81104
/** Look up vid in activeSet, throwing an exception if it is None. */
82105
def isActive(vid: VertexId): Boolean = {
83106
activeSet.get.contains(vid)
@@ -92,11 +115,19 @@ class EdgePartition[
92115
* @return a new edge partition with all edges reversed.
93116
*/
94117
def reverse: EdgePartition[ED, VD] = {
95-
val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD])
96-
for (e <- iterator) {
97-
builder.add(e.dstId, e.srcId, e.attr)
118+
val builder = new VertexPreservingEdgePartitionBuilder(
119+
global2local, local2global, vertexAttrs, size)(classTag[ED], classTag[VD])
120+
var i = 0
121+
while (i < size) {
122+
val localSrcId = localSrcIds(i)
123+
val localDstId = localDstIds(i)
124+
val srcId = local2global(localSrcId)
125+
val dstId = local2global(localDstId)
126+
val attr = data(i)
127+
builder.add(dstId, srcId, localDstId, localSrcId, attr)
128+
i += 1
98129
}
99-
builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
130+
builder.toEdgePartition.withActiveSet(activeSet)
100131
}
101132

102133
/**
@@ -157,13 +188,25 @@ class EdgePartition[
157188
def filter(
158189
epred: EdgeTriplet[VD, ED] => Boolean,
159190
vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = {
160-
val filtered = tripletIterator().filter(et =>
161-
vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et))
162-
val builder = new EdgePartitionBuilder[ED, VD]
163-
for (e <- filtered) {
164-
builder.add(e.srcId, e.dstId, e.attr)
191+
val builder = new VertexPreservingEdgePartitionBuilder[ED, VD](
192+
global2local, local2global, vertexAttrs)
193+
var i = 0
194+
while (i < size) {
195+
// The user sees the EdgeTriplet, so we can't reuse it and must create one per edge.
196+
val localSrcId = localSrcIds(i)
197+
val localDstId = localDstIds(i)
198+
val et = new EdgeTriplet[VD, ED]
199+
et.srcId = local2global(localSrcId)
200+
et.dstId = local2global(localDstId)
201+
et.srcAttr = vertexAttrs(localSrcId)
202+
et.dstAttr = vertexAttrs(localDstId)
203+
et.attr = data(i)
204+
if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) {
205+
builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr)
206+
}
207+
i += 1
165208
}
166-
builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
209+
builder.toEdgePartition.withActiveSet(activeSet)
167210
}
168211

169212
/**
@@ -183,7 +226,8 @@ class EdgePartition[
183226
* @return a new edge partition without duplicate edges
184227
*/
185228
def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = {
186-
val builder = new EdgePartitionBuilder[ED, VD]
229+
val builder = new VertexPreservingEdgePartitionBuilder[ED, VD](
230+
global2local, local2global, vertexAttrs)
187231
var currSrcId: VertexId = null.asInstanceOf[VertexId]
188232
var currDstId: VertexId = null.asInstanceOf[VertexId]
189233
var currAttr: ED = null.asInstanceOf[ED]
@@ -193,7 +237,7 @@ class EdgePartition[
193237
currAttr = merge(currAttr, data(i))
194238
} else {
195239
if (i > 0) {
196-
builder.add(currSrcId, currDstId, currAttr)
240+
builder.add(currSrcId, currDstId, localSrcIds(i - 1), localDstIds(i - 1), currAttr)
197241
}
198242
currSrcId = srcIds(i)
199243
currDstId = dstIds(i)
@@ -202,9 +246,9 @@ class EdgePartition[
202246
i += 1
203247
}
204248
if (size > 0) {
205-
builder.add(currSrcId, currDstId, currAttr)
249+
builder.add(currSrcId, currDstId, localSrcIds(i - 1), localDstIds(i - 1), currAttr)
206250
}
207-
builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
251+
builder.toEdgePartition.withActiveSet(activeSet)
208252
}
209253

210254
/**
@@ -220,7 +264,8 @@ class EdgePartition[
220264
def innerJoin[ED2: ClassTag, ED3: ClassTag]
221265
(other: EdgePartition[ED2, _])
222266
(f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = {
223-
val builder = new EdgePartitionBuilder[ED3, VD]
267+
val builder = new VertexPreservingEdgePartitionBuilder[ED3, VD](
268+
global2local, local2global, vertexAttrs)
224269
var i = 0
225270
var j = 0
226271
// For i = index of each edge in `this`...
@@ -233,20 +278,21 @@ class EdgePartition[
233278
while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
234279
if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
235280
// ... run `f` on the matching edge
236-
builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j)))
281+
builder.add(srcId, dstId, localSrcIds(i), localDstIds(i),
282+
f(srcId, dstId, this.data(i), other.data(j)))
237283
}
238284
}
239285
i += 1
240286
}
241-
builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
287+
builder.toEdgePartition.withActiveSet(activeSet)
242288
}
243289

244290
/**
245291
* The number of edges in this partition
246292
*
247293
* @return size of the partition
248294
*/
249-
val size: Int = srcIds.size
295+
val size: Int = localSrcIds.size
250296

251297
/** The number of unique source vertices in the partition. */
252298
def indexSize: Int = index.size
@@ -285,50 +331,116 @@ class EdgePartition[
285331
}
286332

287333
/**
288-
* Upgrade the given edge iterator into a triplet iterator.
334+
* Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning
335+
* all edges sequentially and filtering them with `idPred`.
336+
*
337+
* @param mapFunc the edge map function which generates messages to neighboring vertices
338+
* @param reduceFunc the combiner applied to messages destined to the same vertex
339+
* @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute
340+
* @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute
341+
* @param idPred a predicate to filter edges based on their source and destination vertex ids
289342
*
290-
* Be careful not to keep references to the objects from this iterator. To improve GC performance
291-
* the same object is re-used in `next()`.
343+
* @return an iterator over the aggregated messages, keyed by the receiving vertex id
292344
*/
293-
def upgradeIterator(
294-
edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true)
295-
: Iterator[EdgeTriplet[VD, ED]] = {
296-
new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst)
345+
def mapReduceTriplets[A: ClassTag](
346+
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
347+
reduceFunc: (A, A) => A,
348+
mapUsesSrcAttr: Boolean,
349+
mapUsesDstAttr: Boolean,
350+
idPred: (VertexId, VertexId) => Boolean): Iterator[(VertexId, A)] = {
351+
val aggregates = new Array[A](vertexAttrs.length)
352+
val bitset = new BitSet(vertexAttrs.length)
353+
354+
var edge = new EdgeTriplet[VD, ED]
355+
var i = 0
356+
while (i < size) {
357+
val localSrcId = localSrcIds(i)
358+
val srcId = local2global(localSrcId)
359+
val localDstId = localDstIds(i)
360+
val dstId = local2global(localDstId)
361+
if (idPred(srcId, dstId)) {
362+
edge.srcId = srcId
363+
edge.dstId = dstId
364+
edge.attr = data(i)
365+
if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(localSrcId) }
366+
if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) }
367+
368+
mapFunc(edge).foreach { kv =>
369+
val globalId = kv._1
370+
val msg = kv._2
371+
val localId = if (globalId == srcId) localSrcId else localDstId
372+
if (bitset.get(localId)) {
373+
aggregates(localId) = reduceFunc(aggregates(localId), msg)
374+
} else {
375+
aggregates(localId) = msg
376+
bitset.set(localId)
377+
}
378+
}
379+
}
380+
i += 1
381+
}
382+
383+
bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
297384
}
298385

299386
/**
300-
* Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
301-
* iterator is generated using an index scan, so it is efficient at skipping edges that don't
302-
* match srcIdPred.
387+
* Send messages along edges and aggregate them at the receiving vertices. Implemented by
388+
* filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering
389+
* with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for an edge to run.
303390
*
304-
* Be careful not to keep references to the objects from this iterator. To improve GC performance
305-
* the same object is re-used in `next()`.
306-
*/
307-
def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] =
308-
index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
309-
310-
/**
311-
* Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
312-
* cluster must start at position `index`.
391+
* @param mapFunc the edge map function which generates messages to neighboring vertices
392+
* @param reduceFunc the combiner applied to messages destined to the same vertex
393+
* @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute
394+
* @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute
395+
* @param srcIdPred a predicate to filter edges based on their source vertex id
396+
* @param dstIdPred a predicate to filter edges based on their destination vertex id
313397
*
314-
* Be careful not to keep references to the objects from this iterator. To improve GC performance
315-
* the same object is re-used in `next()`.
398+
* @return an iterator over the aggregated messages, keyed by the receiving vertex id
316399
*/
317-
private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] {
318-
private[this] val edge = new Edge[ED]
319-
private[this] var pos = index
400+
def mapReduceTripletsWithIndex[A: ClassTag](
401+
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
402+
reduceFunc: (A, A) => A,
403+
mapUsesSrcAttr: Boolean,
404+
mapUsesDstAttr: Boolean,
405+
srcIdPred: VertexId => Boolean,
406+
dstIdPred: VertexId => Boolean): Iterator[(VertexId, A)] = {
407+
val aggregates = new Array[A](vertexAttrs.length)
408+
val bitset = new BitSet(vertexAttrs.length)
320409

321-
override def hasNext: Boolean = {
322-
pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
323-
}
410+
var edge = new EdgeTriplet[VD, ED]
411+
index.iterator.foreach { cluster =>
412+
val clusterSrcId = cluster._1
413+
val clusterPos = cluster._2
414+
val clusterLocalSrcId = localSrcIds(clusterPos)
415+
if (srcIdPred(clusterSrcId)) {
416+
var pos = clusterPos
417+
edge.srcId = clusterSrcId
418+
if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(clusterLocalSrcId) }
419+
while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
420+
val localDstId = localDstIds(pos)
421+
val dstId = local2global(localDstId)
422+
if (dstIdPred(dstId)) {
423+
edge.dstId = dstId
424+
edge.attr = data(pos)
425+
if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) }
324426

325-
override def next(): Edge[ED] = {
326-
assert(srcIds(pos) == srcId)
327-
edge.srcId = srcIds(pos)
328-
edge.dstId = dstIds(pos)
329-
edge.attr = data(pos)
330-
pos += 1
331-
edge
427+
mapFunc(edge).foreach { kv =>
428+
val globalId = kv._1
429+
val msg = kv._2
430+
val localId = if (globalId == clusterSrcId) clusterLocalSrcId else localDstId
431+
if (bitset.get(localId)) {
432+
aggregates(localId) = reduceFunc(aggregates(localId), msg)
433+
} else {
434+
aggregates(localId) = msg
435+
bitset.set(localId)
436+
}
437+
}
438+
}
439+
pos += 1
440+
}
441+
}
332442
}
443+
444+
bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
333445
}
334446
}

0 commit comments

Comments
 (0)