Skip to content

Commit 1f4c7f7

Browse files
semihsalihoglurxin
authored andcommitted
Graph primitives2
Hi guys, I'm following Joey and Ankur's suggestions to add collectEdges and pickRandomVertex. I'm also adding the tests for collectEdges and refactoring one method getCycleGraph in GraphOpsSuite.scala. Thank you, semih Author: Semih Salihoglu <[email protected]> Closes apache#580 from semihsalihoglu/GraphPrimitives2 and squashes the following commits: 937d3ec [Semih Salihoglu] - Fixed the scalastyle errors. a69a152 [Semih Salihoglu] - Adding collectEdges and pickRandomVertices. - Adding tests for collectEdges. - Refactoring a getCycle utility function for GraphOpsSuite.scala. 41265a6 [Semih Salihoglu] - Adding collectEdges and pickRandomVertex. - Adding tests for collectEdges. - Recycling a getCycle utility test file.
1 parent a4f4fbc commit 1f4c7f7

File tree

2 files changed

+183
-10
lines changed

2 files changed

+183
-10
lines changed

graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
package org.apache.spark.graphx
1919

2020
import scala.reflect.ClassTag
21-
2221
import org.apache.spark.SparkContext._
2322
import org.apache.spark.SparkException
2423
import org.apache.spark.graphx.lib._
2524
import org.apache.spark.rdd.RDD
25+
import scala.util.Random
2626

2727
/**
2828
* Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
@@ -137,6 +137,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
137137
}
138138
} // end of collectNeighbor
139139

140+
/**
141+
* Returns an RDD that contains for each vertex v its local edges,
142+
* i.e., the edges that are incident on v, in the user-specified direction.
143+
* Warning: note that singleton vertices, those with no edges in the given
144+
* direction will not be part of the return value.
145+
*
146+
* @note This function could be highly inefficient on power-law
147+
* graphs where high degree vertices may force a large amount of
148+
* information to be collected to a single location.
149+
*
150+
* @param edgeDirection the direction along which to collect
151+
* the local edges of vertices
152+
*
153+
* @return the local edges for each vertex
154+
*/
155+
def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
156+
edgeDirection match {
157+
case EdgeDirection.Either =>
158+
graph.mapReduceTriplets[Array[Edge[ED]]](
159+
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
160+
(edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
161+
(a, b) => a ++ b)
162+
case EdgeDirection.In =>
163+
graph.mapReduceTriplets[Array[Edge[ED]]](
164+
edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
165+
(a, b) => a ++ b)
166+
case EdgeDirection.Out =>
167+
graph.mapReduceTriplets[Array[Edge[ED]]](
168+
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
169+
(a, b) => a ++ b)
170+
case EdgeDirection.Both =>
171+
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
172+
"EdgeDirection.Either instead.")
173+
}
174+
}
175+
140176
/**
141177
* Join the vertices with an RDD and then apply a function from the
142178
* the vertex and RDD entry to a new vertex value. The input table
@@ -209,6 +245,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
209245
graph.mask(preprocess(graph).subgraph(epred, vpred))
210246
}
211247

248+
/**
249+
* Picks a random vertex from the graph and returns its ID.
250+
*/
251+
def pickRandomVertex(): VertexId = {
252+
val probability = 50 / graph.numVertices
253+
var found = false
254+
var retVal: VertexId = null.asInstanceOf[VertexId]
255+
while (!found) {
256+
val selectedVertices = graph.vertices.flatMap { vidVvals =>
257+
if (Random.nextDouble() < probability) { Some(vidVvals._1) }
258+
else { None }
259+
}
260+
if (selectedVertices.count > 1) {
261+
found = true
262+
val collectedVertices = selectedVertices.collect()
263+
retVal = collectedVertices(Random.nextInt(collectedVertices.size))
264+
}
265+
}
266+
retVal
267+
}
268+
212269
/**
213270
* Execute a Pregel-like iterative vertex-parallel abstraction. The
214271
* user-defined vertex-program `vprog` is executed in parallel on

graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala

Lines changed: 125 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
4242

4343
test("collectNeighborIds") {
4444
withSpark { sc =>
45-
val chain = (0 until 100).map(x => (x, (x+1)%100) )
46-
val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
47-
val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
45+
val graph = getCycleGraph(sc, 100)
4846
val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
49-
assert(nbrs.count === chain.size)
47+
assert(nbrs.count === 100)
5048
assert(graph.numVertices === nbrs.count)
5149
nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
52-
nbrs.collect.foreach { case (vid, nbrs) =>
53-
val s = nbrs.toSet
54-
assert(s.contains((vid + 1) % 100))
55-
assert(s.contains(if (vid > 0) vid - 1 else 99 ))
50+
nbrs.collect.foreach {
51+
case (vid, nbrs) =>
52+
val s = nbrs.toSet
53+
assert(s.contains((vid + 1) % 100))
54+
assert(s.contains(if (vid > 0) vid - 1 else 99))
5655
}
5756
}
5857
}
59-
58+
6059
test ("filter") {
6160
withSpark { sc =>
6261
val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
8079
}
8180
}
8281

82+
test("collectEdgesCycleDirectionOut") {
83+
withSpark { sc =>
84+
val graph = getCycleGraph(sc, 100)
85+
val edges = graph.collectEdges(EdgeDirection.Out).cache()
86+
assert(edges.count == 100)
87+
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
88+
edges.collect.foreach {
89+
case (vid, edges) =>
90+
val s = edges.toSet
91+
val edgeDstIds = s.map(e => e.dstId)
92+
assert(edgeDstIds.contains((vid + 1) % 100))
93+
}
94+
}
95+
}
96+
97+
test("collectEdgesCycleDirectionIn") {
98+
withSpark { sc =>
99+
val graph = getCycleGraph(sc, 100)
100+
val edges = graph.collectEdges(EdgeDirection.In).cache()
101+
assert(edges.count == 100)
102+
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
103+
edges.collect.foreach {
104+
case (vid, edges) =>
105+
val s = edges.toSet
106+
val edgeSrcIds = s.map(e => e.srcId)
107+
assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99))
108+
}
109+
}
110+
}
111+
112+
test("collectEdgesCycleDirectionEither") {
113+
withSpark { sc =>
114+
val graph = getCycleGraph(sc, 100)
115+
val edges = graph.collectEdges(EdgeDirection.Either).cache()
116+
assert(edges.count == 100)
117+
edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) }
118+
edges.collect.foreach {
119+
case (vid, edges) =>
120+
val s = edges.toSet
121+
val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
122+
assert(edgeIds.contains((vid + 1) % 100))
123+
assert(edgeIds.contains(if (vid > 0) vid - 1 else 99))
124+
}
125+
}
126+
}
127+
128+
test("collectEdgesChainDirectionOut") {
129+
withSpark { sc =>
130+
val graph = getChainGraph(sc, 50)
131+
val edges = graph.collectEdges(EdgeDirection.Out).cache()
132+
assert(edges.count == 49)
133+
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
134+
edges.collect.foreach {
135+
case (vid, edges) =>
136+
val s = edges.toSet
137+
val edgeDstIds = s.map(e => e.dstId)
138+
assert(edgeDstIds.contains(vid + 1))
139+
}
140+
}
141+
}
142+
143+
test("collectEdgesChainDirectionIn") {
144+
withSpark { sc =>
145+
val graph = getChainGraph(sc, 50)
146+
val edges = graph.collectEdges(EdgeDirection.In).cache()
147+
// We expect only 49 because collectEdges does not return vertices that do
148+
// not have any edges in the specified direction.
149+
assert(edges.count == 49)
150+
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
151+
edges.collect.foreach {
152+
case (vid, edges) =>
153+
val s = edges.toSet
154+
val edgeDstIds = s.map(e => e.srcId)
155+
assert(edgeDstIds.contains((vid - 1) % 100))
156+
}
157+
}
158+
}
159+
160+
test("collectEdgesChainDirectionEither") {
161+
withSpark { sc =>
162+
val graph = getChainGraph(sc, 50)
163+
val edges = graph.collectEdges(EdgeDirection.Either).cache()
164+
// We expect only 49 because collectEdges does not return vertices that do
165+
// not have any edges in the specified direction.
166+
assert(edges.count === 50)
167+
edges.collect.foreach {
168+
case (vid, edges) => if (vid > 0 && vid < 49) assert(edges.size == 2)
169+
else assert(edges.size == 1)
170+
}
171+
edges.collect.foreach {
172+
case (vid, edges) =>
173+
val s = edges.toSet
174+
val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
175+
if (vid == 0) { assert(edgeIds.contains(1)) }
176+
else if (vid == 49) { assert(edgeIds.contains(48)) }
177+
else {
178+
assert(edgeIds.contains(vid + 1))
179+
assert(edgeIds.contains(vid - 1))
180+
}
181+
}
182+
}
183+
}
184+
185+
private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
186+
val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices))
187+
getGraphFromSeq(sc, cycle)
188+
}
189+
190+
private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
191+
val chain = (0 until numVertices - 1).map(x => (x, (x + 1)))
192+
getGraphFromSeq(sc, chain)
193+
}
194+
195+
private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
196+
val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) }
197+
Graph.fromEdgeTuples(rawEdges, 1.0).cache()
198+
}
83199
}

0 commit comments

Comments
 (0)