@@ -21,6 +21,7 @@ import scala.reflect.{classTag, ClassTag}
2121
2222import org .apache .spark .graphx ._
2323import org .apache .spark .graphx .util .collection .GraphXPrimitiveKeyOpenHashMap
24+ import org .apache .spark .util .collection .BitSet
2425
2526/**
2627 * A collection of edges stored in columnar format, along with any vertex attributes referenced. The
@@ -30,54 +31,76 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
3031 * @tparam ED the edge attribute type
3132 * @tparam VD the vertex attribute type
3233 *
33- * @param srcIds the source vertex id of each edge
34- * @param dstIds the destination vertex id of each edge
34+ * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and
35+ * `vertexAttrs`
36+ * @param localDstIds the local destination vertex id of each edge as an index into `local2global`
37+ * and `vertexAttrs`
3538 * @param data the attribute associated with each edge
36- * @param index a clustered index on source vertex id
37- * @param vertices a map from referenced vertex ids to their corresponding attributes. Must
38- * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for
39- * those vertex ids. The mask is not used.
39+ * @param index a clustered index on source vertex id as a map from each global source vertex id to
40+ * the offset in the edge arrays where the cluster for that vertex id begins
41+ * @param global2local a map from referenced vertex ids to local ids which index into `vertexAttrs`
42+ * @param local2global an array of global vertex ids where the offsets are local vertex ids
43+ * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids
4044 * @param activeSet an optional active vertex set for filtering computation on the edges
4145 */
4246private [graphx]
4347class EdgePartition [
4448 @ specialized(Char , Int , Boolean , Byte , Long , Float , Double ) ED : ClassTag , VD : ClassTag ](
45- val srcIds : Array [VertexId ] = null ,
46- val dstIds : Array [VertexId ] = null ,
49+ val localSrcIds : Array [Int ] = null ,
50+ val localDstIds : Array [Int ] = null ,
4751 val data : Array [ED ] = null ,
4852 val index : GraphXPrimitiveKeyOpenHashMap [VertexId , Int ] = null ,
49- val vertices : VertexPartition [VD ] = null ,
53+ val global2local : GraphXPrimitiveKeyOpenHashMap [VertexId , Int ] = null ,
54+ val local2global : Array [VertexId ] = null ,
55+ val vertexAttrs : Array [VD ] = null ,
5056 val activeSet : Option [VertexSet ] = None
5157 ) extends Serializable {
5258
5359 /** Return a new `EdgePartition` with the specified edge data. */
54- def withData [ED2 : ClassTag ](data_ : Array [ED2 ]): EdgePartition [ED2 , VD ] = {
55- new EdgePartition (srcIds, dstIds, data_, index, vertices, activeSet)
56- }
57-
58- /** Return a new `EdgePartition` with the specified vertex partition. */
59- def withVertices [VD2 : ClassTag ](
60- vertices_ : VertexPartition [VD2 ]): EdgePartition [ED , VD2 ] = {
61- new EdgePartition (srcIds, dstIds, data, index, vertices_, activeSet)
60+ def withData [ED2 : ClassTag ](data : Array [ED2 ]): EdgePartition [ED2 , VD ] = {
61+ new EdgePartition (
62+ localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
6263 }
6364
6465 /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */
6566 def withActiveSet (iter : Iterator [VertexId ]): EdgePartition [ED , VD ] = {
66- val newActiveSet = new VertexSet
67- iter.foreach(newActiveSet.add(_))
68- new EdgePartition (srcIds, dstIds, data, index, vertices, Some (newActiveSet))
67+ val activeSet = new VertexSet
68+ iter.foreach(activeSet.add(_))
69+ new EdgePartition (
70+ localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs,
71+ Some (activeSet))
6972 }
7073
7174 /** Return a new `EdgePartition` with the specified active set. */
72- def withActiveSet (activeSet_ : Option [VertexSet ]): EdgePartition [ED , VD ] = {
73- new EdgePartition (srcIds, dstIds, data, index, vertices, activeSet_)
75+ def withActiveSet (activeSet : Option [VertexSet ]): EdgePartition [ED , VD ] = {
76+ new EdgePartition (
77+ localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
7478 }
7579
7680 /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */
7781 def updateVertices (iter : Iterator [(VertexId , VD )]): EdgePartition [ED , VD ] = {
78- this .withVertices(vertices.innerJoinKeepLeft(iter))
82+ val newVertexAttrs = new Array [VD ](vertexAttrs.length)
83+ System .arraycopy(vertexAttrs, 0 , newVertexAttrs, 0 , vertexAttrs.length)
84+ iter.foreach { kv =>
85+ newVertexAttrs(global2local(kv._1)) = kv._2
86+ }
87+ new EdgePartition (
88+ localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
89+ activeSet)
90+ }
91+
92+ /** Return a new `EdgePartition` without any locally cached vertex attributes. */
93+ def clearVertices [VD2 : ClassTag ](): EdgePartition [ED , VD2 ] = {
94+ val newVertexAttrs = new Array [VD2 ](vertexAttrs.length)
95+ new EdgePartition (
96+ localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
97+ activeSet)
7998 }
8099
100+ def srcIds (i : Int ): VertexId = local2global(localSrcIds(i))
101+
102+ def dstIds (i : Int ): VertexId = local2global(localDstIds(i))
103+
81104 /** Look up vid in activeSet, throwing an exception if it is None. */
82105 def isActive (vid : VertexId ): Boolean = {
83106 activeSet.get.contains(vid)
@@ -92,11 +115,19 @@ class EdgePartition[
92115 * @return a new edge partition with all edges reversed.
93116 */
94117 def reverse : EdgePartition [ED , VD ] = {
95- val builder = new EdgePartitionBuilder (size)(classTag[ED ], classTag[VD ])
96- for (e <- iterator) {
97- builder.add(e.dstId, e.srcId, e.attr)
118+ val builder = new VertexPreservingEdgePartitionBuilder (
119+ global2local, local2global, vertexAttrs, size)(classTag[ED ], classTag[VD ])
120+ var i = 0
121+ while (i < size) {
122+ val localSrcId = localSrcIds(i)
123+ val localDstId = localDstIds(i)
124+ val srcId = local2global(localSrcId)
125+ val dstId = local2global(localDstId)
126+ val attr = data(i)
127+ builder.add(dstId, srcId, localDstId, localSrcId, attr)
128+ i += 1
98129 }
99- builder.toEdgePartition.withVertices(vertices). withActiveSet(activeSet)
130+ builder.toEdgePartition.withActiveSet(activeSet)
100131 }
101132
102133 /**
@@ -157,13 +188,25 @@ class EdgePartition[
157188 def filter (
158189 epred : EdgeTriplet [VD , ED ] => Boolean ,
159190 vpred : (VertexId , VD ) => Boolean ): EdgePartition [ED , VD ] = {
160- val filtered = tripletIterator().filter(et =>
161- vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et))
162- val builder = new EdgePartitionBuilder [ED , VD ]
163- for (e <- filtered) {
164- builder.add(e.srcId, e.dstId, e.attr)
191+ val builder = new VertexPreservingEdgePartitionBuilder [ED , VD ](
192+ global2local, local2global, vertexAttrs)
193+ var i = 0
194+ while (i < size) {
195+ // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge.
196+ val localSrcId = localSrcIds(i)
197+ val localDstId = localDstIds(i)
198+ val et = new EdgeTriplet [VD , ED ]
199+ et.srcId = local2global(localSrcId)
200+ et.dstId = local2global(localDstId)
201+ et.srcAttr = vertexAttrs(localSrcId)
202+ et.dstAttr = vertexAttrs(localDstId)
203+ et.attr = data(i)
204+ if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) {
205+ builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr)
206+ }
207+ i += 1
165208 }
166- builder.toEdgePartition.withVertices(vertices). withActiveSet(activeSet)
209+ builder.toEdgePartition.withActiveSet(activeSet)
167210 }
168211
169212 /**
@@ -183,7 +226,8 @@ class EdgePartition[
183226 * @return a new edge partition without duplicate edges
184227 */
185228 def groupEdges (merge : (ED , ED ) => ED ): EdgePartition [ED , VD ] = {
186- val builder = new EdgePartitionBuilder [ED , VD ]
229+ val builder = new VertexPreservingEdgePartitionBuilder [ED , VD ](
230+ global2local, local2global, vertexAttrs)
187231 var currSrcId : VertexId = null .asInstanceOf [VertexId ]
188232 var currDstId : VertexId = null .asInstanceOf [VertexId ]
189233 var currAttr : ED = null .asInstanceOf [ED ]
@@ -193,7 +237,7 @@ class EdgePartition[
193237 currAttr = merge(currAttr, data(i))
194238 } else {
195239 if (i > 0 ) {
196- builder.add(currSrcId, currDstId, currAttr)
240+ builder.add(currSrcId, currDstId, localSrcIds(i - 1 ), localDstIds(i - 1 ), currAttr)
197241 }
198242 currSrcId = srcIds(i)
199243 currDstId = dstIds(i)
@@ -202,9 +246,9 @@ class EdgePartition[
202246 i += 1
203247 }
204248 if (size > 0 ) {
205- builder.add(currSrcId, currDstId, currAttr)
249+ builder.add(currSrcId, currDstId, localSrcIds(i - 1 ), localDstIds(i - 1 ), currAttr)
206250 }
207- builder.toEdgePartition.withVertices(vertices). withActiveSet(activeSet)
251+ builder.toEdgePartition.withActiveSet(activeSet)
208252 }
209253
210254 /**
@@ -220,7 +264,8 @@ class EdgePartition[
220264 def innerJoin [ED2 : ClassTag , ED3 : ClassTag ]
221265 (other : EdgePartition [ED2 , _])
222266 (f : (VertexId , VertexId , ED , ED2 ) => ED3 ): EdgePartition [ED3 , VD ] = {
223- val builder = new EdgePartitionBuilder [ED3 , VD ]
267+ val builder = new VertexPreservingEdgePartitionBuilder [ED3 , VD ](
268+ global2local, local2global, vertexAttrs)
224269 var i = 0
225270 var j = 0
226271 // For i = index of each edge in `this`...
@@ -233,20 +278,21 @@ class EdgePartition[
233278 while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
234279 if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
235280 // ... run `f` on the matching edge
236- builder.add(srcId, dstId, f(srcId, dstId, this .data(i), other.data(j)))
281+ builder.add(srcId, dstId, localSrcIds(i), localDstIds(i),
282+ f(srcId, dstId, this .data(i), other.data(j)))
237283 }
238284 }
239285 i += 1
240286 }
241- builder.toEdgePartition.withVertices(vertices). withActiveSet(activeSet)
287+ builder.toEdgePartition.withActiveSet(activeSet)
242288 }
243289
244290 /**
245291 * The number of edges in this partition
246292 *
247293 * @return size of the partition
248294 */
249- val size : Int = srcIds .size
295+ val size : Int = localSrcIds .size
250296
251297 /** The number of unique source vertices in the partition. */
252298 def indexSize : Int = index.size
@@ -285,50 +331,116 @@ class EdgePartition[
285331 }
286332
287333 /**
288- * Upgrade the given edge iterator into a triplet iterator.
334+ * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning
335+ * all edges sequentially and filtering them with `idPred`.
336+ *
337+ * @param mapFunc the edge map function which generates messages to neighboring vertices
338+ * @param reduceFunc the combiner applied to messages destined to the same vertex
339+ * @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute
340+ * @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute
341+ * @param idPred a predicate to filter edges based on their source and destination vertex ids
289342 *
290- * Be careful not to keep references to the objects from this iterator. To improve GC performance
291- * the same object is re-used in `next()`.
343+ * @return an iterator of aggregated messages keyed by the receiving vertex id
292344 */
293- def upgradeIterator (
294- edgeIter : Iterator [Edge [ED ]], includeSrc : Boolean = true , includeDst : Boolean = true )
295- : Iterator [EdgeTriplet [VD , ED ]] = {
296- new ReusingEdgeTripletIterator (edgeIter, this , includeSrc, includeDst)
345+ def mapReduceTriplets [A : ClassTag ](
346+ mapFunc : EdgeTriplet [VD , ED ] => Iterator [(VertexId , A )],
347+ reduceFunc : (A , A ) => A ,
348+ mapUsesSrcAttr : Boolean ,
349+ mapUsesDstAttr : Boolean ,
350+ idPred : (VertexId , VertexId ) => Boolean ): Iterator [(VertexId , A )] = {
351+ val aggregates = new Array [A ](vertexAttrs.length)
352+ val bitset = new BitSet (vertexAttrs.length)
353+
354+ var edge = new EdgeTriplet [VD , ED ]
355+ var i = 0
356+ while (i < size) {
357+ val localSrcId = localSrcIds(i)
358+ val srcId = local2global(localSrcId)
359+ val localDstId = localDstIds(i)
360+ val dstId = local2global(localDstId)
361+ if (idPred(srcId, dstId)) {
362+ edge.srcId = srcId
363+ edge.dstId = dstId
364+ edge.attr = data(i)
365+ if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(localSrcId) }
366+ if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) }
367+
368+ mapFunc(edge).foreach { kv =>
369+ val globalId = kv._1
370+ val msg = kv._2
371+ val localId = if (globalId == srcId) localSrcId else localDstId
372+ if (bitset.get(localId)) {
373+ aggregates(localId) = reduceFunc(aggregates(localId), msg)
374+ } else {
375+ aggregates(localId) = msg
376+ bitset.set(localId)
377+ }
378+ }
379+ }
380+ i += 1
381+ }
382+
383+ bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
297384 }
298385
299386 /**
300- * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
301- * iterator is generated using an index scan, so it is efficient at skipping edges that don't
302- * match srcIdPred.
387+ * Send messages along edges and aggregate them at the receiving vertices. Implemented by
388+ * filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering
389+ * with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for `mapFunc` to run on an edge.
303390 *
304- * Be careful not to keep references to the objects from this iterator. To improve GC performance
305- * the same object is re-used in `next()`.
306- */
307- def indexIterator (srcIdPred : VertexId => Boolean ): Iterator [Edge [ED ]] =
308- index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function .tupled(clusterIterator))
309-
310- /**
311- * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
312- * cluster must start at position `index`.
391+ * @param mapFunc the edge map function which generates messages to neighboring vertices
392+ * @param reduceFunc the combiner applied to messages destined to the same vertex
393+ * @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute
394+ * @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute
395+ * @param srcIdPred a predicate to filter edges based on their source vertex id
396+ * @param dstIdPred a predicate to filter edges based on their destination vertex id
313397 *
314- * Be careful not to keep references to the objects from this iterator. To improve GC performance
315- * the same object is re-used in `next()`.
398+ * @return an iterator of aggregated messages keyed by the receiving vertex id
316399 */
317- private def clusterIterator (srcId : VertexId , index : Int ) = new Iterator [Edge [ED ]] {
318- private [this ] val edge = new Edge [ED ]
319- private [this ] var pos = index
400+ def mapReduceTripletsWithIndex [A : ClassTag ](
401+ mapFunc : EdgeTriplet [VD , ED ] => Iterator [(VertexId , A )],
402+ reduceFunc : (A , A ) => A ,
403+ mapUsesSrcAttr : Boolean ,
404+ mapUsesDstAttr : Boolean ,
405+ srcIdPred : VertexId => Boolean ,
406+ dstIdPred : VertexId => Boolean ): Iterator [(VertexId , A )] = {
407+ val aggregates = new Array [A ](vertexAttrs.length)
408+ val bitset = new BitSet (vertexAttrs.length)
320409
321- override def hasNext : Boolean = {
322- pos >= 0 && pos < EdgePartition .this .size && srcIds(pos) == srcId
323- }
410+ var edge = new EdgeTriplet [VD , ED ]
411+ index.iterator.foreach { cluster =>
412+ val clusterSrcId = cluster._1
413+ val clusterPos = cluster._2
414+ val clusterLocalSrcId = localSrcIds(clusterPos)
415+ if (srcIdPred(clusterSrcId)) {
416+ var pos = clusterPos
417+ edge.srcId = clusterSrcId
418+ if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(clusterLocalSrcId) }
419+ while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
420+ val localDstId = localDstIds(pos)
421+ val dstId = local2global(localDstId)
422+ if (dstIdPred(dstId)) {
423+ edge.dstId = dstId
424+ edge.attr = data(pos)
425+ if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) }
324426
325- override def next (): Edge [ED ] = {
326- assert(srcIds(pos) == srcId)
327- edge.srcId = srcIds(pos)
328- edge.dstId = dstIds(pos)
329- edge.attr = data(pos)
330- pos += 1
331- edge
427+ mapFunc(edge).foreach { kv =>
428+ val globalId = kv._1
429+ val msg = kv._2
430+ val localId = if (globalId == clusterSrcId) clusterLocalSrcId else localDstId
431+ if (bitset.get(localId)) {
432+ aggregates(localId) = reduceFunc(aggregates(localId), msg)
433+ } else {
434+ aggregates(localId) = msg
435+ bitset.set(localId)
436+ }
437+ }
438+ }
439+ pos += 1
440+ }
441+ }
332442 }
443+
444+ bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
333445 }
334446}
0 commit comments