diff --git a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
index 45a13f2762c4..7a9b73e82d43 100644
--- a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
+++ b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
@@ -21,6 +21,7 @@
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -52,6 +53,8 @@
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.RuntimeConfig;
@@ -205,15 +208,18 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
     // broadcast the table metadata as input partitions will be sent to executors
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
 
-    List<InputPartition<ColumnarBatch>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableBroadcast, expectedSchemaString, caseSensitive,
-          localityPreferred, new BatchReaderFactory(batchSize)));
-    }
-    LOG.info("Batching input partitions with {} tasks.", readTasks.size());
+    List<CombinedScanTask> scanTasks = tasks();
+    InputPartition<ColumnarBatch>[] readTasks = new InputPartition[scanTasks.size()];
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask<>(
+            scanTasks.get(index), tableBroadcast, expectedSchemaString, caseSensitive,
+            localityPreferred, new BatchReaderFactory(batchSize)));
+    LOG.info("Batching input partitions with {} tasks.", readTasks.length);
 
-    return readTasks;
+    return Arrays.asList(readTasks);
   }
 
   /**
@@ -226,14 +232,17 @@ public List<InputPartition<InternalRow>> planInputPartitions() {
     // broadcast the table metadata as input partitions will be sent to executors
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
 
-    List<InputPartition<InternalRow>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableBroadcast, expectedSchemaString, caseSensitive,
-          localityPreferred, InternalRowReaderFactory.INSTANCE));
-    }
+    List<CombinedScanTask> scanTasks = tasks();
+    InputPartition<InternalRow>[] readTasks = new InputPartition[scanTasks.size()];
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask<>(
+            scanTasks.get(index), tableBroadcast, expectedSchemaString, caseSensitive,
+            localityPreferred, InternalRowReaderFactory.INSTANCE));
 
-    return readTasks;
+    return Arrays.asList(readTasks);
   }
 
   @Override
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
index 8cf755b82eb9..baf40341cece 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
@@ -40,6 +40,8 @@
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.RuntimeConfig;
@@ -132,11 +134,13 @@ public InputPartition[] planInputPartitions() {
 
     List<CombinedScanTask> scanTasks = tasks();
     InputPartition[] readTasks = new InputPartition[scanTasks.size()];
-    for (int i = 0; i < scanTasks.size(); i++) {
-      readTasks[i] = new ReadTask(
-          scanTasks.get(i), tableBroadcast, expectedSchemaString,
-          caseSensitive, localityPreferred);
-    }
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask(
+            scanTasks.get(index), tableBroadcast, expectedSchemaString,
+            caseSensitive, localityPreferred));
 
     return readTasks;
   }
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
index 84fdbd288de7..e08c5a1287bd 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
@@ -52,6 +52,8 @@
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.connector.read.InputPartition;
@@ -136,11 +138,12 @@ public InputPartition[] planInputPartitions(Offset start, Offset end) {
         TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost));
 
     InputPartition[] readTasks = new InputPartition[combinedScanTasks.size()];
-    for (int i = 0; i < combinedScanTasks.size(); i++) {
-      readTasks[i] = new ReadTask(
-          combinedScanTasks.get(i), tableBroadcast, expectedSchema,
-          caseSensitive, localityPreferred);
-    }
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask(
+            combinedScanTasks.get(index), tableBroadcast, expectedSchema,
+            caseSensitive, localityPreferred));
 
     return readTasks;
   }
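
For readers unfamiliar with the utility classes the patch pulls in, the sketch below isolates the Tasks.range(...) fan-out pattern used in each hunk. It is a minimal illustration, not part of the change: the class ParallelPlanningSketch, the results array, and the parallel flag are hypothetical stand-ins for the diff's readTasks array and localityPreferred flag; only Tasks, ThreadPools, and the chained calls appear in the patch itself, and the sketch assumes Tasks falls back to running inline on the calling thread when no executor is supplied, which is how the patch's ternary reads.

    // Hypothetical, self-contained sketch of the Tasks.range(...) fan-out used above.
    import java.util.concurrent.ExecutorService;
    import org.apache.iceberg.util.Tasks;
    import org.apache.iceberg.util.ThreadPools;

    public class ParallelPlanningSketch {
      public static void main(String[] args) {
        String[] results = new String[8];   // stand-in for the InputPartition[] readTasks array
        boolean parallel = true;            // stand-in for localityPreferred

        // With a null executor the work runs serially on the calling thread,
        // which is why the patch only hands over the worker pool when localityPreferred is set.
        ExecutorService pool = parallel ? ThreadPools.getWorkerPool() : null;

        Tasks.range(results.length)         // one task per array index, 0..length-1
            .stopOnFailure()                // stop scheduling remaining tasks after the first failure
            .executeWith(pool)              // fan out to the shared Iceberg worker pool
            .run(index -> results[index] = "partition-" + index);  // each index fills its own slot

        for (String result : results) {
          System.out.println(result);
        }
      }
    }

Because every index writes only its own array slot, the tasks are independent and need no extra synchronization; the shared worker pool then bounds how many read tasks are constructed concurrently.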