diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index a570e4ed75fc..894cd97675b8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -223,8 +223,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { val vertices: RDD[(VertexId, Int)] = sc.parallelize(Array((1L, 1), (2L, 2))) val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse - val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(result.collect().toSet === Set((1L, 2))) + val result = graph.aggregateMessages[Int](ctx => ctx.sendToDst(ctx.srcAttr), _ + _) + assert(result.collect.toSet === Set((1L, 2))) } } @@ -346,12 +346,13 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type - val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { - (vid, a, bOpt) => bOpt.getOrElse(0) - } - val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( - et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), - (a: Int, b: Int) => a + b).collect().toSet + val reverseStarDegrees = + reverseStar.outerJoinVertices(reverseStar.outDegrees) { + (vid, a, bOpt) => bOpt.getOrElse(0) } + val neighborDegreeSums = reverseStarDegrees.aggregateMessages[Int]( ctx => { + ctx.sendToSrc(ctx.dstAttr) + ctx.sendToDst(ctx.srcAttr) + }, (a: Int, b: Int) => a + b).collect.toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } @@ -422,9 +423,9 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), numEdgePartitions) val graph = Graph.fromEdgeTuples(edges, 1) - val neighborAttrSums = graph.mapReduceTriplets[Int]( - et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) + val neighborAttrSums = graph.aggregateMessages[Int]( + ctx => ctx.sendToDst(ctx.srcAttr), _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) } finally { sc.stop() }