Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 196 additions & 78 deletions graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap

/** Constructs an EdgePartition from scratch. */
private[graphx]
class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
size: Int = 64) {
Expand All @@ -38,19 +39,77 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
def toEdgePartition: EdgePartition[ED, VD] = {
val edgeArray = edges.trim().array
Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
val srcIds = new Array[VertexId](edgeArray.size)
val dstIds = new Array[VertexId](edgeArray.size)
val localSrcIds = new Array[Int](edgeArray.size)
val localDstIds = new Array[Int](edgeArray.size)
val data = new Array[ED](edgeArray.size)
val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
val local2global = new PrimitiveVector[VertexId]
var vertexAttrs = Array.empty[VD]
// Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
// adding them to the index. Also populate a map from vertex id to a sequential local offset.
if (edgeArray.length > 0) {
index.update(edgeArray(0).srcId, 0)
var currSrcId: VertexId = edgeArray(0).srcId
var currLocalId = -1
var i = 0
while (i < edgeArray.size) {
val srcId = edgeArray(i).srcId
val dstId = edgeArray(i).dstId
localSrcIds(i) = global2local.changeValue(srcId,
{ currLocalId += 1; local2global += srcId; currLocalId }, identity)
localDstIds(i) = global2local.changeValue(dstId,
{ currLocalId += 1; local2global += dstId; currLocalId }, identity)
data(i) = edgeArray(i).attr
if (srcId != currSrcId) {
currSrcId = srcId
index.update(currSrcId, i)
}

i += 1
}
vertexAttrs = new Array[VD](currLocalId + 1)
}
new EdgePartition(
localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs)
}
}

/**
* Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables
* reuse of the local vertex ids. Intended for internal use in EdgePartition only.
*/
private[impl]
class ExistingEdgePartitionBuilder[
@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
local2global: Array[VertexId],
vertexAttrs: Array[VD],
activeSet: Option[VertexSet],
size: Int = 64) {
var edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this private[this] val instead

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also do it for the one in EdgePartitionBuilder


/**
 * Add a new edge to the partition.
 *
 * The edge is wrapped in an [[EdgeWithLocalIds]] and appended to the `edges` buffer.
 *
 * @param src global id of the source vertex
 * @param dst global id of the destination vertex
 * @param localSrc local id supplied by the caller for the source vertex
 * @param localDst local id supplied by the caller for the destination vertex
 * @param d the edge attribute
 */
def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED): Unit = {
  edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d)
}

def toEdgePartition: EdgePartition[ED, VD] = {
val edgeArray = edges.trim().array
Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering)
val localSrcIds = new Array[Int](edgeArray.size)
val localDstIds = new Array[Int](edgeArray.size)
val data = new Array[ED](edgeArray.size)
val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
// Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
// adding them to the index
if (edgeArray.length > 0) {
index.update(srcIds(0), 0)
var currSrcId: VertexId = srcIds(0)
index.update(edgeArray(0).srcId, 0)
var currSrcId: VertexId = edgeArray(0).srcId
var i = 0
while (i < edgeArray.size) {
srcIds(i) = edgeArray(i).srcId
dstIds(i) = edgeArray(i).dstId
localSrcIds(i) = edgeArray(i).localSrcId
localDstIds(i) = edgeArray(i).localDstId
data(i) = edgeArray(i).attr
if (edgeArray(i).srcId != currSrcId) {
currSrcId = edgeArray(i).srcId
Expand All @@ -60,13 +119,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
}
}

// Create and populate a VertexPartition with vids from the edges, but no attributes
val vidsIter = srcIds.iterator ++ dstIds.iterator
val vertexIds = new OpenHashSet[VertexId]
vidsIter.foreach(vid => vertexIds.add(vid))
val vertices = new VertexPartition(
vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet)
new EdgePartition(
localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
}
}

new EdgePartition(srcIds, dstIds, data, index, vertices)
/**
 * An edge record that carries, next to the global vertex ids of its endpoints (`srcId`,
 * `dstId`) and the edge attribute (`attr`), a pair of integer ids (`localSrcId`, `localDstId`).
 *
 * NOTE(review): the integer ids are presumably the partition-local offsets of the same two
 * endpoints -- confirm against the callers that construct these records.
 *
 * @tparam ED the edge attribute type
 */
private[impl] case class EdgeWithLocalIds[@specialized ED](
srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED)

private[impl] object EdgeWithLocalIds {
  /**
   * Ordering that sorts edges by source vertex id and breaks ties by destination vertex id.
   * The local ids and the edge attribute do not take part in the comparison, so two edges with
   * the same (srcId, dstId) pair compare as equal.
   *
   * The result type is spelled out explicitly: an implicit definition without one is harder to
   * read and can be skipped by implicit search at use sites that precede it in the same file.
   *
   * @tparam ED the edge attribute type
   * @return an ordering that yields -1, 0 or 1
   */
  implicit def lexicographicOrdering[ED]: Ordering[EdgeWithLocalIds[ED]] =
    new Ordering[EdgeWithLocalIds[ED]] {
      override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = {
        if (a.srcId == b.srcId) {
          // Same source: fall back to the destination id.
          if (a.dstId == b.dstId) 0
          else if (a.dstId < b.dstId) -1
          else 1
        } else if (a.srcId < b.srcId) -1
        else 1
      }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,45 +40,18 @@ class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](

override def next() = {
val triplet = new EdgeTriplet[VD, ED]
triplet.srcId = edgePartition.srcIds(pos)
val localSrcId = edgePartition.localSrcIds(pos)
val localDstId = edgePartition.localDstIds(pos)
triplet.srcId = edgePartition.local2global(localSrcId)
triplet.dstId = edgePartition.local2global(localDstId)
if (includeSrc) {
triplet.srcAttr = edgePartition.vertices(triplet.srcId)
triplet.srcAttr = edgePartition.vertexAttrs(localSrcId)
}
triplet.dstId = edgePartition.dstIds(pos)
if (includeDst) {
triplet.dstAttr = edgePartition.vertices(triplet.dstId)
triplet.dstAttr = edgePartition.vertexAttrs(localDstId)
}
triplet.attr = edgePartition.data(pos)
pos += 1
triplet
}
}

/**
 * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous
 * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug /
 * profile.
 *
 * Every call to `next()` returns the same `EdgeTriplet` instance, so a caller that needs to keep
 * an element beyond the following call to `next()` must copy it first.
 *
 * @param edgeIter the edges to turn into triplets
 * @param edgePartition the partition whose `vertices` are looked up for vertex attributes
 * @param includeSrc whether to populate `srcAttr` on the returned triplet
 * @param includeDst whether to populate `dstAttr` on the returned triplet
 */
private[impl]
class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag](
    val edgeIter: Iterator[Edge[ED]],
    val edgePartition: EdgePartition[ED, VD],
    val includeSrc: Boolean,
    val includeDst: Boolean)
  extends Iterator[EdgeTriplet[VD, ED]] {

  // The single triplet instance handed back by every call to next().
  private val triplet = new EdgeTriplet[VD, ED]

  override def hasNext: Boolean = edgeIter.hasNext

  override def next(): EdgeTriplet[VD, ED] = {
    triplet.set(edgeIter.next())
    // Vertex attributes are only looked up when the caller asked for them.
    if (includeSrc) {
      triplet.srcAttr = edgePartition.vertices(triplet.srcId)
    }
    if (includeDst) {
      triplet.dstAttr = edgePartition.vertices(triplet.dstId)
    }
    triplet
  }
}
46 changes: 25 additions & 21 deletions graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.storage.StorageLevel

import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl._
import org.apache.spark.graphx.util.BytecodeUtils
Expand Down Expand Up @@ -193,37 +192,44 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
case (pid, edgePartition) =>
// Choose scan method
val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
val edgeIter = activeDirectionOpt match {
activeDirectionOpt match {
case Some(EdgeDirection.Both) =>
if (activeFraction < 0.8) {
edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId))
.filter(e => edgePartition.isActive(e.dstId))
edgePartition.mapReduceTripletsWithIndex(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
srcId => edgePartition.isActive(srcId),
dstId => edgePartition.isActive(dstId))
} else {
edgePartition.iterator.filter(e =>
edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId))
edgePartition.mapReduceTriplets(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
(srcId, dstId) => edgePartition.isActive(srcId) && edgePartition.isActive(dstId))
}
case Some(EdgeDirection.Either) =>
// TODO: Because we only have a clustered index on the source vertex ID, we can't filter
// the index here. Instead we have to scan all edges and then do the filter.
edgePartition.iterator.filter(e =>
edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId))
edgePartition.mapReduceTriplets(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
(srcId, dstId) => edgePartition.isActive(srcId) || edgePartition.isActive(dstId))
case Some(EdgeDirection.Out) =>
if (activeFraction < 0.8) {
edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId))
edgePartition.mapReduceTripletsWithIndex(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
srcId => edgePartition.isActive(srcId),
dstId => true)
} else {
edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId))
edgePartition.mapReduceTriplets(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
(srcId, dstId) => edgePartition.isActive(srcId))
}
case Some(EdgeDirection.In) =>
edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId))
edgePartition.mapReduceTriplets(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
(srcId, dstId) => edgePartition.isActive(dstId))
case _ => // None
edgePartition.iterator
edgePartition.mapReduceTriplets(
mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr,
(srcId, dstId) => true)
}

// Scan edges and run the map function
val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr)
.flatMap(mapFunc(_))
// Note: This doesn't allow users to send messages to arbitrary vertices.
edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
}).setName("GraphImpl.mapReduceTriplets - preAgg")

// do the final reduction reusing the index map
Expand Down Expand Up @@ -306,9 +312,7 @@ object GraphImpl {
vertices: VertexRDD[VD],
edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = {
// Convert the vertex partitions in edges to the correct type
val newEdges = edges.mapEdgePartitions(
(pid, part) => part.withVertices(part.vertices.map(
(vid, attr) => null.asInstanceOf[VD])))
val newEdges = edges.mapEdgePartitions((pid, part) => part.clearVertices[VD])
GraphImpl.fromExistingRDDs(vertices, newEdges)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,9 @@ object RoutingTablePartition {
// Determine which positions each vertex id appears in using a map where the low 2 bits
// represent src and dst
val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte]
edgePartition.srcIds.iterator.foreach { srcId =>
map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
}
edgePartition.dstIds.iterator.foreach { dstId =>
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
edgePartition.iterator.foreach { e =>
map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
}
map.iterator.map { vidAndPosition =>
val vid = vidAndPosition._1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
// Each vertex should be replicated to at most 2 * sqrt(p) partitions
val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter =>
val part = iter.next()._2
Iterator((part.srcIds ++ part.dstIds).toSet)
Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet)
}.collect
if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) {
val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound)
Expand All @@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
// This should not be true for the default hash partitioning
val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter =>
val part = iter.next()._2
Iterator((part.srcIds ++ part.dstIds).toSet)
Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet)
}.collect
assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite {
assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges)
}

test("upgradeIterator") {
val edges = List((0, 1, 0), (1, 0, 0))
val verts = List((0L, 1), (1L, 2))
val part = makeEdgePartition(edges).updateVertices(verts.iterator)
assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList ===
part.tripletIterator().toList.map(_.toTuple))
}

test("indexIterator") {
val edgesFrom0 = List(Edge(0, 1, 0))
val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
val sortedEdges = edgesFrom0 ++ edgesFrom1
val builder = new EdgePartitionBuilder[Int, Nothing]
for (e <- Random.shuffle(sortedEdges)) {
builder.add(e.srcId, e.dstId, e.attr)
}

val edgePartition = builder.toEdgePartition
assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
}

test("innerJoin") {
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0))
Expand All @@ -126,7 +103,7 @@ class EdgePartitionSuite extends FunSuite {
}

test("serialization") {
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5))
val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
val javaSer = new JavaSerializer(new SparkConf())
val conf = new SparkConf()
Expand All @@ -135,11 +112,8 @@ class EdgePartitionSuite extends FunSuite {

for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
assert(aSer.srcIds.toList === a.srcIds.toList)
assert(aSer.dstIds.toList === a.dstIds.toList)
assert(aSer.data.toList === a.data.toList)
assert(aSer.tripletIterator().toList === a.tripletIterator().toList)
assert(aSer.index != null)
assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
}
}
}