Skip to content

Commit e8e2867

Browse files
committed
Updates based on Marcelo's review feedback
1 parent 7a1417f commit e8e2867

File tree

3 files changed: 30 additions & 50 deletions (+30 / -50 lines changed)

3 files changed: 30 additions & 50 deletions (+30 / -50 lines changed)

core/src/main/scala/org/apache/spark/FutureAction.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark
1919

20+
import java.util.Collections
2021
import java.util.concurrent.TimeUnit
2122

2223
import org.apache.spark.api.java.JavaFutureAction
@@ -285,12 +286,12 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
285286

286287
override def isDone: Boolean = {
287288
// According to java.util.Future's Javadoc, this returns True if the task was completed,
288-
// whether that completion was due to succesful execution, an exception, or a cancellation.
289+
// whether that completion was due to successful execution, an exception, or a cancellation.
289290
futureAction.isCancelled || futureAction.isCompleted
290291
}
291292

292293
override def jobIds(): java.util.List[java.lang.Integer] = {
293-
new java.util.ArrayList(futureAction.jobIds.map(x => new Integer(x)).asJava)
294+
Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava)
294295
}
295296

296297
private def getImpl(timeout: Duration): T = {
@@ -300,10 +301,10 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
300301
case scala.util.Success(value) => converter(value)
301302
case Failure(exception) =>
302303
if (isCancelled) {
303-
throw new CancellationException("Job cancelled: ${exception.message}");
304+
throw new CancellationException("Job cancelled").initCause(exception)
304305
} else {
305306
// java.util.Future.get() wraps exceptions in ExecutionException
306-
throw new ExecutionException("Exception thrown by job: ", exception)
307+
throw new ExecutionException("Exception thrown by job", exception)
307308
}
308309
}
309310
}
@@ -313,7 +314,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
313314
override def get(timeout: Long, unit: TimeUnit): T =
314315
getImpl(Duration.fromNanos(unit.toNanos(timeout)))
315316

316-
override def cancel(mayInterruptIfRunning: Boolean): Boolean = {
317+
override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized {
317318
if (isDone) {
318319
// According to java.util.Future's Javadoc, this should return false if the task is completed.
319320
false

core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ import java.util.{Comparator, List => JList, Iterator => JIterator}
2121
import java.lang.{Iterable => JIterable, Long => JLong}
2222

2323
import scala.collection.JavaConversions._
24+
import scala.collection.JavaConverters._
2425
import scala.reflect.ClassTag
2526

2627
import com.google.common.base.Optional
2728
import org.apache.hadoop.io.compress.CompressionCodec
2829

2930
import org.apache.spark._
31+
import org.apache.spark.SparkContext._
3032
import org.apache.spark.annotation.Experimental
3133
import org.apache.spark.api.java.JavaPairRDD._
3234
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
@@ -578,34 +580,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
578580
* future for counting the number of elements in this RDD.
579581
*/
580582
def countAsync(): JavaFutureAction[JLong] = {
581-
import org.apache.spark.SparkContext._
582-
new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), x => new JLong(x))
583+
new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf)
583584
}
584585

585586
/**
586587
* The asynchronous version of `collect`, which returns a future for
587588
* retrieving an array containing all of the elements in this RDD.
588589
*/
589590
def collectAsync(): JavaFutureAction[JList[T]] = {
590-
import org.apache.spark.SparkContext._
591-
new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => new java.util.ArrayList(x))
591+
new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava)
592592
}
593593

594594
/**
595595
* The asynchronous version of the `take` action, which returns a
596596
* future for retrieving the first `num` elements of this RDD.
597597
*/
598598
def takeAsync(num: Int): JavaFutureAction[JList[T]] = {
599-
import org.apache.spark.SparkContext._
600-
new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => new java.util.ArrayList(x))
599+
new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava)
601600
}
602601

603602
/**
604603
* The asynchronous version of the `foreach` action, which
605604
* applies a function f to all the elements of this RDD.
606605
*/
607606
def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = {
608-
import org.apache.spark.SparkContext._
609607
new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)),
610608
{ x => null.asInstanceOf[Void] })
611609
}
@@ -615,7 +613,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
615613
* applies a function f to each partition of this RDD.
616614
*/
617615
def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = {
618-
import org.apache.spark.SparkContext._
619616
new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)),
620617
{ x => null.asInstanceOf[Void] })
621618
}

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.common.collect.Iterators;
3131
import com.google.common.collect.Lists;
3232
import com.google.common.collect.Maps;
33+
import com.google.common.base.Throwables;
3334
import com.google.common.base.Optional;
3435
import com.google.common.base.Charsets;
3536
import com.google.common.io.Files;
@@ -1306,21 +1307,6 @@ public void collectUnderlyingScalaRDD() {
13061307
Assert.assertEquals(data.size(), collected.length);
13071308
}
13081309

1309-
private static final class IdentityWithDelay<T> implements Function<T, T> {
1310-
1311-
final int delayMillis;
1312-
1313-
IdentityWithDelay(int delayMillis) {
1314-
this.delayMillis = delayMillis;
1315-
}
1316-
1317-
@Override
1318-
public T call(T x) throws Exception {
1319-
Thread.sleep(delayMillis);
1320-
return x;
1321-
}
1322-
}
1323-
13241310
private static final class BuggyMapFunction<T> implements Function<T, T> {
13251311

13261312
@Override
@@ -1333,62 +1319,59 @@ public T call(T x) throws Exception {
13331319
public void collectAsync() throws Exception {
13341320
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
13351321
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
1336-
JavaFutureAction<List<Integer>> future =
1337-
rdd.map(new IdentityWithDelay<Integer>(200)).collectAsync();
1338-
Assert.assertFalse(future.isCancelled());
1339-
Assert.assertFalse(future.isDone());
1322+
JavaFutureAction<List<Integer>> future = rdd.collectAsync();
13401323
List<Integer> result = future.get();
1341-
Assert.assertEquals(result, data);
1324+
Assert.assertEquals(data, result);
13421325
Assert.assertFalse(future.isCancelled());
13431326
Assert.assertTrue(future.isDone());
1344-
Assert.assertEquals(future.jobIds().size(), 1);
1327+
Assert.assertEquals(1, future.jobIds().size());
13451328
}
13461329

13471330
@Test
13481331
public void foreachAsync() throws Exception {
13491332
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
13501333
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
1351-
JavaFutureAction<Void> future = rdd.map(new IdentityWithDelay<Integer>(200)).foreachAsync(
1334+
JavaFutureAction<Void> future = rdd.foreachAsync(
13521335
new VoidFunction<Integer>() {
13531336
@Override
13541337
public void call(Integer integer) throws Exception {
13551338
// intentionally left blank.
13561339
}
13571340
}
13581341
);
1359-
Assert.assertFalse(future.isCancelled());
1360-
Assert.assertFalse(future.isDone());
13611342
future.get();
13621343
Assert.assertFalse(future.isCancelled());
13631344
Assert.assertTrue(future.isDone());
1364-
Assert.assertEquals(future.jobIds().size(), 1);
1345+
Assert.assertEquals(1, future.jobIds().size());
13651346
}
13661347

13671348
@Test
13681349
public void countAsync() throws Exception {
13691350
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
13701351
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
1371-
JavaFutureAction<Long> future = rdd.map(new IdentityWithDelay<Integer>(200)).countAsync();
1372-
Assert.assertFalse(future.isCancelled());
1373-
Assert.assertFalse(future.isDone());
1352+
JavaFutureAction<Long> future = rdd.countAsync();
13741353
long count = future.get();
1375-
Assert.assertEquals(count, data.size());
1354+
Assert.assertEquals(data.size(), count);
13761355
Assert.assertFalse(future.isCancelled());
13771356
Assert.assertTrue(future.isDone());
1378-
Assert.assertEquals(future.jobIds().size(), 1);
1357+
Assert.assertEquals(1, future.jobIds().size());
13791358
}
13801359

13811360
@Test
13821361
public void testAsyncActionCancellation() throws Exception {
13831362
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
13841363
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
1385-
JavaFutureAction<Long> future = rdd.map(new IdentityWithDelay<Integer>(200)).countAsync();
1386-
Thread.sleep(200);
1364+
JavaFutureAction<Void> future = rdd.foreachAsync(new VoidFunction<Integer>() {
1365+
@Override
1366+
public void call(Integer integer) throws Exception {
1367+
Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled.
1368+
}
1369+
});
13871370
future.cancel(true);
13881371
Assert.assertTrue(future.isCancelled());
13891372
Assert.assertTrue(future.isDone());
13901373
try {
1391-
long count = future.get(2000, TimeUnit.MILLISECONDS);
1374+
future.get(2000, TimeUnit.MILLISECONDS);
13921375
Assert.fail("Expected future.get() for cancelled job to throw CancellationException");
13931376
} catch (CancellationException ignored) {
13941377
// pass
@@ -1400,12 +1383,11 @@ public void testAsyncActionErrorWrapping() throws Exception {
14001383
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
14011384
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
14021385
JavaFutureAction<Long> future = rdd.map(new BuggyMapFunction<Integer>()).countAsync();
1403-
Thread.sleep(200);
14041386
try {
1405-
long count = future.get(2000, TimeUnit.MILLISECONDS);
1387+
long count = future.get(2, TimeUnit.SECONDS);
14061388
Assert.fail("Expected future.get() for failed job to throw ExcecutionException");
1407-
} catch (ExecutionException ignored) {
1408-
// pass
1389+
} catch (ExecutionException ee) {
1390+
Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
14091391
}
14101392
Assert.assertTrue(future.isDone());
14111393
}

0 commit comments

Comments
 (0)