Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* }}}
*
*/
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2)
(implicit eq: VD =:= VD2 = null): Graph[VD2, ED]

/**
* Transforms each edge attribute in the graph using the map function. The map function is not
Expand Down Expand Up @@ -338,7 +339,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* }}}
*/
def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
(mapFunc: (VertexId, VD, Option[U]) => VD2)
(mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null)
: Graph[VD2, ED]

/**
Expand Down
14 changes: 10 additions & 4 deletions graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse())
}

override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = {
if (classTag[VD] equals classTag[VD2]) {
override def mapVertices[VD2: ClassTag]
(f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
// The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
// null if not
if (eq != null) {
vertices.cache()
// The map preserves type, so we can use incremental replication
val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
Expand Down Expand Up @@ -228,8 +231,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (

override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
(other: RDD[(VertexId, U)])
(updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = {
if (classTag[VD] equals classTag[VD2]) {
(updateF: (VertexId, VD, Option[U]) => VD2)
(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
// The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
// null if not
if (eq != null) {
vertices.cache()
// updateF preserves type, so we can use incremental replication
val newVerts = vertices.leftJoin(other)(updateF).cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object LabelPropagation {
*
* @return a graph with vertex attributes containing the label of community affiliation
*/
def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = {
def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = {
val lpaGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(e: EdgeTriplet[VertexId, ED]) = {
Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ object ShortestPaths {
* @return a graph where each vertex attribute is a map containing the shortest-path distance to
* each reachable landmark vertex.
*/
def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
val spGraph = graph.mapVertices { (vid, attr) =>
if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
}
Expand Down
25 changes: 25 additions & 0 deletions graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}

test("mapVertices changing type with same erased type") {
withSpark { sc =>
val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])](
(1L, Some(1)),
(2L, Some(2)),
(3L, Some(3))
))
val edges = sc.parallelize(Array(
Edge(1L, 2L, 0),
Edge(2L, 3L, 0),
Edge(3L, 1L, 0)
))
val graph0 = Graph(vertices, edges)
// Trigger initial vertex replication
graph0.triplets.foreach(x => {})
// Change type of replicated vertices, but preserve erased type
val graph1 = graph0.mapVertices {
case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double)
}
// Access replicated vertices, exposing the erased type
val graph2 = graph1.mapTriplets(t => t.srcAttr.get)
assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0))
}
}

test("mapEdges") {
withSpark { sc =>
val n = 3
Expand Down