Skip to content

Commit d7ebff0

Browse files
committed
Merge pull request #1 from ankurdave/add_project_to_graph
Merge current master and reimplement Graph.mask using innerJoin
2 parents cb20175 + 0f137e8 commit d7ebff0

38 files changed

+2448
-1985
lines changed

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
245245
if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
246246
throw new SparkException("Default partitioner cannot partition array keys.")
247247
}
248-
new ShuffledRDD[K, V, (K, V)](self, partitioner)
248+
if (self.partitioner == partitioner) self else new ShuffledRDD[K, V, (K, V)](self, partitioner)
249249
}
250250

251251
/**

core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ import java.io.{ObjectOutputStream, IOException}
2222

2323
private[spark] class ZippedPartitionsPartition(
2424
idx: Int,
25-
@transient rdds: Seq[RDD[_]])
25+
@transient rdds: Seq[RDD[_]],
26+
@transient val preferredLocations: Seq[String])
2627
extends Partition {
2728

2829
override val index: Int = idx
@@ -47,27 +48,21 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
4748
if (preservesPartitioning) firstParent[Any].partitioner else None
4849

4950
override def getPartitions: Array[Partition] = {
50-
val sizes = rdds.map(x => x.partitions.size)
51-
if (!sizes.forall(x => x == sizes(0))) {
51+
val numParts = rdds.head.partitions.size
52+
if (!rdds.forall(rdd => rdd.partitions.size == numParts)) {
5253
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
5354
}
54-
val array = new Array[Partition](sizes(0))
55-
for (i <- 0 until sizes(0)) {
56-
array(i) = new ZippedPartitionsPartition(i, rdds)
55+
Array.tabulate[Partition](numParts) { i =>
56+
val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
57+
// Check whether there are any hosts that match all RDDs; otherwise return the union
58+
val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
59+
val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct
60+
new ZippedPartitionsPartition(i, rdds, locs)
5761
}
58-
array
5962
}
6063

6164
override def getPreferredLocations(s: Partition): Seq[String] = {
62-
val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions
63-
val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) }
64-
// Check whether there are any hosts that match all RDDs; otherwise return the union
65-
val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
66-
if (!exactMatchLocations.isEmpty) {
67-
exactMatchLocations
68-
} else {
69-
prefs.flatten.distinct
70-
}
65+
s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
7166
}
7267

7368
override def clearDependencies() {

core/src/main/scala/org/apache/spark/util/collection/BitSet.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ class BitSet(numBits: Int) {
102102
words(index >> 6) |= bitmask // div by 64 and mask
103103
}
104104

105+
def unset(index: Int) {
106+
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
107+
words(index >> 6) &= ~bitmask // div by 64 and mask
108+
}
105109

106110
/**
107111
* Return the value of the bit with the specified index. The value is true if the bit with

core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
158158
/** Return the value at the specified position. */
159159
def getValue(pos: Int): T = _data(pos)
160160

161-
def iterator() = new Iterator[T] {
161+
def iterator = new Iterator[T] {
162162
var pos = nextPos(0)
163163
override def hasNext: Boolean = pos != INVALID_POS
164164
override def next(): T = {
@@ -249,8 +249,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
249249
* in the lower bits, similar to java.util.HashMap
250250
*/
251251
private def hashcode(h: Int): Int = {
252-
val r = h ^ (h >>> 20) ^ (h >>> 12)
253-
r ^ (r >>> 7) ^ (r >>> 4)
252+
it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
254253
}
255254

256255
private def nextPowerOf2(n: Int): Int = {

core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
7171

7272

7373
/** Set the value for a key */
74-
def setMerge(k: K, v: V, mergeF: (V,V) => V) {
74+
def setMerge(k: K, v: V, mergeF: (V, V) => V) {
7575
val pos = keySet.addWithoutResize(k)
7676
val ind = pos & OpenHashSet.POSITION_MASK
7777
if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add

0 commit comments

Comments (0)