@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
4242
4343 test(" collectNeighborIds" ) {
4444 withSpark { sc =>
45- val chain = (0 until 100 ).map(x => (x, (x+ 1 )% 100 ) )
46- val rawEdges = sc.parallelize(chain, 3 ).map { case (s,d) => (s.toLong, d.toLong) }
47- val graph = Graph .fromEdgeTuples(rawEdges, 1.0 ).cache()
45+ val graph = getCycleGraph(sc, 100 )
4846 val nbrs = graph.collectNeighborIds(EdgeDirection .Either ).cache()
49- assert(nbrs.count === chain.size )
47+ assert(nbrs.count === 100 )
5048 assert(graph.numVertices === nbrs.count)
5149 nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2 ) }
52- nbrs.collect.foreach { case (vid, nbrs) =>
53- val s = nbrs.toSet
54- assert(s.contains((vid + 1 ) % 100 ))
55- assert(s.contains(if (vid > 0 ) vid - 1 else 99 ))
50+ nbrs.collect.foreach {
51+ case (vid, nbrs) =>
52+ val s = nbrs.toSet
53+ assert(s.contains((vid + 1 ) % 100 ))
54+ assert(s.contains(if (vid > 0 ) vid - 1 else 99 ))
5655 }
5756 }
5857 }
59-
58+
6059 test (" filter" ) {
6160 withSpark { sc =>
6261 val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
8079 }
8180 }
8281
82+ test(" collectEdgesCycleDirectionOut" ) {
83+ withSpark { sc =>
84+ val graph = getCycleGraph(sc, 100 )
85+ val edges = graph.collectEdges(EdgeDirection .Out ).cache()
86+ assert(edges.count == 100 )
87+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1 ) }
88+ edges.collect.foreach {
89+ case (vid, edges) =>
90+ val s = edges.toSet
91+ val edgeDstIds = s.map(e => e.dstId)
92+ assert(edgeDstIds.contains((vid + 1 ) % 100 ))
93+ }
94+ }
95+ }
96+
97+ test(" collectEdgesCycleDirectionIn" ) {
98+ withSpark { sc =>
99+ val graph = getCycleGraph(sc, 100 )
100+ val edges = graph.collectEdges(EdgeDirection .In ).cache()
101+ assert(edges.count == 100 )
102+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1 ) }
103+ edges.collect.foreach {
104+ case (vid, edges) =>
105+ val s = edges.toSet
106+ val edgeSrcIds = s.map(e => e.srcId)
107+ assert(edgeSrcIds.contains(if (vid > 0 ) vid - 1 else 99 ))
108+ }
109+ }
110+ }
111+
112+ test(" collectEdgesCycleDirectionEither" ) {
113+ withSpark { sc =>
114+ val graph = getCycleGraph(sc, 100 )
115+ val edges = graph.collectEdges(EdgeDirection .Either ).cache()
116+ assert(edges.count == 100 )
117+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 2 ) }
118+ edges.collect.foreach {
119+ case (vid, edges) =>
120+ val s = edges.toSet
121+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
122+ assert(edgeIds.contains((vid + 1 ) % 100 ))
123+ assert(edgeIds.contains(if (vid > 0 ) vid - 1 else 99 ))
124+ }
125+ }
126+ }
127+
128+ test(" collectEdgesChainDirectionOut" ) {
129+ withSpark { sc =>
130+ val graph = getChainGraph(sc, 50 )
131+ val edges = graph.collectEdges(EdgeDirection .Out ).cache()
132+ assert(edges.count == 49 )
133+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1 ) }
134+ edges.collect.foreach {
135+ case (vid, edges) =>
136+ val s = edges.toSet
137+ val edgeDstIds = s.map(e => e.dstId)
138+ assert(edgeDstIds.contains(vid + 1 ))
139+ }
140+ }
141+ }
142+
143+ test(" collectEdgesChainDirectionIn" ) {
144+ withSpark { sc =>
145+ val graph = getChainGraph(sc, 50 )
146+ val edges = graph.collectEdges(EdgeDirection .In ).cache()
147+ // We expect only 49 because collectEdges does not return vertices that do
148+ // not have any edges in the specified direction.
149+ assert(edges.count == 49 )
150+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1 ) }
151+ edges.collect.foreach {
152+ case (vid, edges) =>
153+ val s = edges.toSet
154+ val edgeDstIds = s.map(e => e.srcId)
155+ assert(edgeDstIds.contains((vid - 1 ) % 100 ))
156+ }
157+ }
158+ }
159+
160+ test(" collectEdgesChainDirectionEither" ) {
161+ withSpark { sc =>
162+ val graph = getChainGraph(sc, 50 )
163+ val edges = graph.collectEdges(EdgeDirection .Either ).cache()
164+ // We expect only 49 because collectEdges does not return vertices that do
165+ // not have any edges in the specified direction.
166+ assert(edges.count === 50 )
167+ edges.collect.foreach {
168+ case (vid, edges) => if (vid > 0 && vid < 49 ) assert(edges.size == 2 )
169+ else assert(edges.size == 1 )
170+ }
171+ edges.collect.foreach {
172+ case (vid, edges) =>
173+ val s = edges.toSet
174+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
175+ if (vid == 0 ) { assert(edgeIds.contains(1 )) }
176+ else if (vid == 49 ) { assert(edgeIds.contains(48 )) }
177+ else {
178+ assert(edgeIds.contains(vid + 1 ))
179+ assert(edgeIds.contains(vid - 1 ))
180+ }
181+ }
182+ }
183+ }
184+
185+ private def getCycleGraph (sc : SparkContext , numVertices : Int ): Graph [Double , Int ] = {
186+ val cycle = (0 until numVertices).map(x => (x, (x + 1 ) % numVertices))
187+ getGraphFromSeq(sc, cycle)
188+ }
189+
190+ private def getChainGraph (sc : SparkContext , numVertices : Int ): Graph [Double , Int ] = {
191+ val chain = (0 until numVertices - 1 ).map(x => (x, (x + 1 )))
192+ getGraphFromSeq(sc, chain)
193+ }
194+
195+ private def getGraphFromSeq (sc : SparkContext , seq : IndexedSeq [(Int , Int )]): Graph [Double , Int ] = {
196+ val rawEdges = sc.parallelize(seq, 3 ).map { case (s, d) => (s.toLong, d.toLong) }
197+ Graph .fromEdgeTuples(rawEdges, 1.0 ).cache()
198+ }
83199}
0 commit comments