Skip to content

Commit 29e39e1

Browse files
committed
[SPARK-3266] Use intermediate abstract classes to fix type erasure issues in Java APIs
This PR addresses a Scala compiler bug ([SI-8905](https://issues.scala-lang.org/browse/SI-8905)) that was breaking some of the Spark Java APIs. In a nutshell, it seems that methods whose implementations are inherited from generic traits sometimes have their type parameters erased to Object. This was causing methods like `DoubleRDD.min()` to throw confusing NoSuchMethodErrors at runtime. The fix implemented here is to introduce an intermediate layer of abstract classes and inherit from those instead of directly extending the `Java*Like` traits. This should not break binary compatibility. I also improved the test coverage of the Java API, adding several new tests for methods that failed at runtime due to this bug. Author: Josh Rosen <[email protected]> Closes #5050 from JoshRosen/javardd-si-8905-fix and squashes the following commits: 2feb068 [Josh Rosen] Use intermediate abstract classes to work around SPARK-3266 d5f3e5d [Josh Rosen] Add failing regression tests for SPARK-3266 (cherry picked from commit 0f673c2) Signed-off-by: Josh Rosen <[email protected]>
1 parent 95f8d1c commit 29e39e1

File tree

8 files changed

+152
-5
lines changed

8 files changed

+152
-5
lines changed

core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ import org.apache.spark.storage.StorageLevel
3232
import org.apache.spark.util.StatCounter
3333
import org.apache.spark.util.Utils
3434

35-
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] {
35+
class JavaDoubleRDD(val srdd: RDD[scala.Double])
36+
extends AbstractJavaRDDLike[JDouble, JavaDoubleRDD] {
3637

3738
override val classTag: ClassTag[JDouble] = implicitly[ClassTag[JDouble]]
3839

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import org.apache.spark.util.Utils
4444

4545
class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
4646
(implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V])
47-
extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
47+
extends AbstractJavaRDDLike[(K, V), JavaPairRDD[K, V]] {
4848

4949
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
5050

core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel
3030
import org.apache.spark.util.Utils
3131

3232
class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
33-
extends JavaRDDLike[T, JavaRDD[T]] {
33+
extends AbstractJavaRDDLike[T, JavaRDD[T]] {
3434

3535
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
3636

core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ import org.apache.spark.rdd.RDD
3838
import org.apache.spark.storage.StorageLevel
3939
import org.apache.spark.util.Utils
4040

41+
/**
42+
* As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations
43+
* of JavaRDDLike should extend this dummy abstract class instead of directly inheriting
44+
* from the trait. See SPARK-3266 for additional details.
45+
*/
46+
private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This]]
47+
extends JavaRDDLike[T, This]
48+
4149
/**
4250
* Defines operations common to several Java RDD implementations.
4351
* Note that this trait is not intended to be implemented by user code.

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,22 @@ public void call(String s) throws IOException {
267267
Assert.assertEquals(2, accum.value().intValue());
268268
}
269269

270+
@Test
271+
public void foreachPartition() {
272+
final Accumulator<Integer> accum = sc.accumulator(0);
273+
JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World"));
274+
rdd.foreachPartition(new VoidFunction<Iterator<String>>() {
275+
@Override
276+
public void call(Iterator<String> iter) throws IOException {
277+
while (iter.hasNext()) {
278+
iter.next();
279+
accum.add(1);
280+
}
281+
}
282+
});
283+
Assert.assertEquals(2, accum.value().intValue());
284+
}
285+
270286
@Test
271287
public void toLocalIterator() {
272288
List<Integer> correct = Arrays.asList(1, 2, 3, 4);
@@ -657,6 +673,13 @@ public Boolean call(Integer i) {
657673
}).isEmpty());
658674
}
659675

676+
@Test
677+
public void toArray() {
678+
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3));
679+
List<Integer> list = rdd.toArray();
680+
Assert.assertEquals(Arrays.asList(1, 2, 3), list);
681+
}
682+
660683
@Test
661684
public void cartesian() {
662685
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
@@ -710,6 +733,80 @@ public void javaDoubleRDDHistoGram() {
710733
Assert.assertArrayEquals(expected_counts, histogram);
711734
}
712735

736+
private static class DoubleComparator implements Comparator<Double>, Serializable {
737+
public int compare(Double o1, Double o2) {
738+
return o1.compareTo(o2);
739+
}
740+
}
741+
742+
@Test
743+
public void max() {
744+
JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
745+
double max = rdd.max(new DoubleComparator());
746+
Assert.assertEquals(4.0, max, 0.001);
747+
}
748+
749+
@Test
750+
public void min() {
751+
JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
752+
double max = rdd.min(new DoubleComparator());
753+
Assert.assertEquals(1.0, max, 0.001);
754+
}
755+
756+
@Test
757+
public void takeOrdered() {
758+
JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
759+
Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator()));
760+
Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2));
761+
}
762+
763+
@Test
764+
public void top() {
765+
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
766+
List<Integer> top2 = rdd.top(2);
767+
Assert.assertEquals(Arrays.asList(4, 3), top2);
768+
}
769+
770+
private static class AddInts implements Function2<Integer, Integer, Integer> {
771+
@Override
772+
public Integer call(Integer a, Integer b) {
773+
return a + b;
774+
}
775+
}
776+
777+
@Test
778+
public void reduce() {
779+
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
780+
int sum = rdd.reduce(new AddInts());
781+
Assert.assertEquals(10, sum);
782+
}
783+
784+
@Test
785+
public void reduceOnJavaDoubleRDD() {
786+
JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
787+
double sum = rdd.reduce(new Function2<Double, Double, Double>() {
788+
@Override
789+
public Double call(Double v1, Double v2) throws Exception {
790+
return v1 + v2;
791+
}
792+
});
793+
Assert.assertEquals(10.0, sum, 0.001);
794+
}
795+
796+
@Test
797+
public void fold() {
798+
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
799+
int sum = rdd.fold(0, new AddInts());
800+
Assert.assertEquals(10, sum);
801+
}
802+
803+
@Test
804+
public void aggregate() {
805+
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
806+
int sum = rdd.aggregate(0, new AddInts(), new AddInts());
807+
Assert.assertEquals(10, sum);
808+
}
809+
713810
@Test
714811
public void map() {
715812
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
@@ -826,6 +923,25 @@ public Iterable<Integer> call(Iterator<Integer> iter) {
826923
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
827924
}
828925

926+
927+
@Test
928+
public void mapPartitionsWithIndex() {
929+
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
930+
JavaRDD<Integer> partitionSums = rdd.mapPartitionsWithIndex(
931+
new Function2<Integer, Iterator<Integer>, Iterator<Integer>>() {
932+
@Override
933+
public Iterator<Integer> call(Integer index, Iterator<Integer> iter) throws Exception {
934+
int sum = 0;
935+
while (iter.hasNext()) {
936+
sum += iter.next();
937+
}
938+
return Collections.singletonList(sum).iterator();
939+
}
940+
}, false);
941+
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
942+
}
943+
944+
829945
@Test
830946
public void repartition() {
831947
// Shrinking number of partitions
@@ -1512,6 +1628,19 @@ public void collectAsync() throws Exception {
15121628
Assert.assertEquals(1, future.jobIds().size());
15131629
}
15141630

1631+
@Test
1632+
public void takeAsync() throws Exception {
1633+
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
1634+
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
1635+
JavaFutureAction<List<Integer>> future = rdd.takeAsync(1);
1636+
List<Integer> result = future.get();
1637+
Assert.assertEquals(1, result.size());
1638+
Assert.assertEquals((Integer) 1, result.get(0));
1639+
Assert.assertFalse(future.isCancelled());
1640+
Assert.assertTrue(future.isDone());
1641+
Assert.assertEquals(1, future.jobIds().size());
1642+
}
1643+
15151644
@Test
15161645
public void foreachAsync() throws Exception {
15171646
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import org.apache.spark.streaming.dstream.DStream
3636
* [[org.apache.spark.streaming.api.java.JavaPairDStream]].
3737
*/
3838
class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T])
39-
extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] {
39+
extends AbstractJavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] {
4040

4141
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
4242

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ import org.apache.spark.streaming._
3434
import org.apache.spark.streaming.api.java.JavaDStream._
3535
import org.apache.spark.streaming.dstream.DStream
3636

37+
/**
38+
* As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations
39+
* of JavaDStreamLike should extend this dummy abstract class instead of directly inheriting
40+
* from the trait. See SPARK-3266 for additional details.
41+
*/
42+
private[streaming]
43+
abstract class AbstractJavaDStreamLike[T, This <: JavaDStreamLike[T, This, R],
44+
R <: JavaRDDLike[T, R]] extends JavaDStreamLike[T, This, R]
45+
3746
trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]]
3847
extends Serializable {
3948
implicit val classTag: ClassTag[T]

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ import org.apache.spark.streaming.dstream.DStream
4545
class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
4646
implicit val kManifest: ClassTag[K],
4747
implicit val vManifest: ClassTag[V])
48-
extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] {
48+
extends AbstractJavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] {
4949

5050
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
5151

0 commit comments

Comments
 (0)