Skip to content

Commit 0402be9

Browse files
ankurdaverxin
authored andcommitted
Internal cleanup for aggregateMessages
1. Add EdgeActiveness enum to represent activeness criteria more cleanly than using booleans. 2. Comments and whitespace. Author: Ankur Dave <[email protected]> Closes #3231 from ankurdave/aggregateMessages-followup and squashes the following commits: 3d485c3 [Ankur Dave] Internal cleanup for aggregateMessages
1 parent aa43a8d commit 0402be9

File tree

4 files changed

+69
-34
lines changed

4 files changed

+69
-34
lines changed

graphx/src/main/scala/org/apache/spark/graphx/Graph.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
207207
* }}}
208208
*
209209
*/
210-
def mapTriplets[ED2: ClassTag](
211-
map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
210+
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
212211
mapTriplets((pid, iter) => iter.map(map), TripletFields.All)
213212
}
214213

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.graphx.impl;
19+
20+
/**
21+
* Criteria for filtering edges based on activeness. For internal use only.
22+
*/
23+
public enum EdgeActiveness {
24+
/** Neither the source vertex nor the destination vertex need be active. */
25+
Neither,
26+
/** The source vertex must be active. */
27+
SrcOnly,
28+
/** The destination vertex must be active. */
29+
DstOnly,
30+
/** Both vertices must be active. */
31+
Both,
32+
/** At least one vertex must be active. */
33+
Either
34+
}

graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class EdgePartition[
6464
activeSet: Option[VertexSet])
6565
extends Serializable {
6666

67+
/** No-arg constructor for serialization. */
6768
private def this() = this(null, null, null, null, null, null, null, null)
6869

6970
/** Return a new `EdgePartition` with the specified edge data. */
@@ -375,22 +376,15 @@ class EdgePartition[
375376
* @param sendMsg generates messages to neighboring vertices of an edge
376377
* @param mergeMsg the combiner applied to messages destined to the same vertex
377378
* @param tripletFields which triplet fields `sendMsg` uses
378-
* @param srcMustBeActive if true, edges will only be considered if their source vertex is in the
379-
* active set
380-
* @param dstMustBeActive if true, edges will only be considered if their destination vertex is in
381-
* the active set
382-
* @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be
383-
* considered
379+
* @param activeness criteria for filtering edges based on activeness
384380
*
385381
* @return iterator aggregated messages keyed by the receiving vertex id
386382
*/
387383
def aggregateMessagesEdgeScan[A: ClassTag](
388384
sendMsg: EdgeContext[VD, ED, A] => Unit,
389385
mergeMsg: (A, A) => A,
390386
tripletFields: TripletFields,
391-
srcMustBeActive: Boolean,
392-
dstMustBeActive: Boolean,
393-
maySatisfyEither: Boolean): Iterator[(VertexId, A)] = {
387+
activeness: EdgeActiveness): Iterator[(VertexId, A)] = {
394388
val aggregates = new Array[A](vertexAttrs.length)
395389
val bitset = new BitSet(vertexAttrs.length)
396390

@@ -401,10 +395,13 @@ class EdgePartition[
401395
val srcId = local2global(localSrcId)
402396
val localDstId = localDstIds(i)
403397
val dstId = local2global(localDstId)
404-
val srcIsActive = !srcMustBeActive || isActive(srcId)
405-
val dstIsActive = !dstMustBeActive || isActive(dstId)
406398
val edgeIsActive =
407-
if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive
399+
if (activeness == EdgeActiveness.Neither) true
400+
else if (activeness == EdgeActiveness.SrcOnly) isActive(srcId)
401+
else if (activeness == EdgeActiveness.DstOnly) isActive(dstId)
402+
else if (activeness == EdgeActiveness.Both) isActive(srcId) && isActive(dstId)
403+
else if (activeness == EdgeActiveness.Either) isActive(srcId) || isActive(dstId)
404+
else throw new Exception("unreachable")
408405
if (edgeIsActive) {
409406
val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD]
410407
val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
@@ -424,22 +421,15 @@ class EdgePartition[
424421
* @param sendMsg generates messages to neighboring vertices of an edge
425422
* @param mergeMsg the combiner applied to messages destined to the same vertex
426423
* @param tripletFields which triplet fields `sendMsg` uses
427-
* @param srcMustBeActive if true, edges will only be considered if their source vertex is in the
428-
* active set
429-
* @param dstMustBeActive if true, edges will only be considered if their destination vertex is in
430-
* the active set
431-
* @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be
432-
* considered
424+
* @param activeness criteria for filtering edges based on activeness
433425
*
434426
* @return iterator aggregated messages keyed by the receiving vertex id
435427
*/
436428
def aggregateMessagesIndexScan[A: ClassTag](
437429
sendMsg: EdgeContext[VD, ED, A] => Unit,
438430
mergeMsg: (A, A) => A,
439431
tripletFields: TripletFields,
440-
srcMustBeActive: Boolean,
441-
dstMustBeActive: Boolean,
442-
maySatisfyEither: Boolean): Iterator[(VertexId, A)] = {
432+
activeness: EdgeActiveness): Iterator[(VertexId, A)] = {
443433
val aggregates = new Array[A](vertexAttrs.length)
444434
val bitset = new BitSet(vertexAttrs.length)
445435

@@ -448,18 +438,30 @@ class EdgePartition[
448438
val clusterSrcId = cluster._1
449439
val clusterPos = cluster._2
450440
val clusterLocalSrcId = localSrcIds(clusterPos)
451-
val srcIsActive = !srcMustBeActive || isActive(clusterSrcId)
452-
if (srcIsActive || maySatisfyEither) {
441+
442+
val scanCluster =
443+
if (activeness == EdgeActiveness.Neither) true
444+
else if (activeness == EdgeActiveness.SrcOnly) isActive(clusterSrcId)
445+
else if (activeness == EdgeActiveness.DstOnly) true
446+
else if (activeness == EdgeActiveness.Both) isActive(clusterSrcId)
447+
else if (activeness == EdgeActiveness.Either) true
448+
else throw new Exception("unreachable")
449+
450+
if (scanCluster) {
453451
var pos = clusterPos
454452
val srcAttr =
455453
if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD]
456454
ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr)
457455
while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
458456
val localDstId = localDstIds(pos)
459457
val dstId = local2global(localDstId)
460-
val dstIsActive = !dstMustBeActive || isActive(dstId)
461458
val edgeIsActive =
462-
if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive
459+
if (activeness == EdgeActiveness.Neither) true
460+
else if (activeness == EdgeActiveness.SrcOnly) true
461+
else if (activeness == EdgeActiveness.DstOnly) isActive(dstId)
462+
else if (activeness == EdgeActiveness.Both) isActive(dstId)
463+
else if (activeness == EdgeActiveness.Either) isActive(clusterSrcId) || isActive(dstId)
464+
else throw new Exception("unreachable")
463465
if (edgeIsActive) {
464466
val dstAttr =
465467
if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]

graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,30 +218,30 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
218218
case Some(EdgeDirection.Both) =>
219219
if (activeFraction < 0.8) {
220220
edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
221-
true, true, false)
221+
EdgeActiveness.Both)
222222
} else {
223223
edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
224-
true, true, false)
224+
EdgeActiveness.Both)
225225
}
226226
case Some(EdgeDirection.Either) =>
227227
// TODO: Because we only have a clustered index on the source vertex ID, we can't filter
228228
// the index here. Instead we have to scan all edges and then do the filter.
229229
edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
230-
true, true, true)
230+
EdgeActiveness.Either)
231231
case Some(EdgeDirection.Out) =>
232232
if (activeFraction < 0.8) {
233233
edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
234-
true, false, false)
234+
EdgeActiveness.SrcOnly)
235235
} else {
236236
edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
237-
true, false, false)
237+
EdgeActiveness.SrcOnly)
238238
}
239239
case Some(EdgeDirection.In) =>
240240
edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
241-
false, true, false)
241+
EdgeActiveness.DstOnly)
242242
case _ => // None
243243
edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
244-
false, false, false)
244+
EdgeActiveness.Neither)
245245
}
246246
}).setName("GraphImpl.aggregateMessages - preAgg")
247247

0 commit comments

Comments
 (0)