diff --git a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
index 3db5beee26..9f40a67db1 100644
--- a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
+++ b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
@@ -21,6 +21,7 @@
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -57,6 +58,8 @@
 import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.RuntimeConfig;
 import org.apache.spark.sql.SparkSession;
@@ -233,15 +236,19 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
 
     ValidationException.check(tasks().stream().noneMatch(TableScanUtil::hasDeletes),
         "Cannot scan table %s: cannot apply required delete files", table);
-    List<InputPartition<ColumnarBatch>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableSchemaString, expectedSchemaString, nameMappingString, io, encryptionManager, caseSensitive,
-          localityPreferred, new BatchReaderFactory(batchSize), ignoreFileFieldIds));
-    }
-    LOG.info("Batching input partitions with {} tasks.", readTasks.size());
+    List<CombinedScanTask> scanTasks = tasks();
+    InputPartition<ColumnarBatch>[] readTasks = new InputPartition[scanTasks.size()];
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask<>(
+            scanTasks.get(index), tableSchemaString, expectedSchemaString, nameMappingString, io,
+            encryptionManager, caseSensitive,
+            localityPreferred, new BatchReaderFactory(batchSize), ignoreFileFieldIds));
+    LOG.info("Batching input partitions with {} tasks.", readTasks.length);
 
-    return readTasks;
+    return Arrays.asList(readTasks);
   }
 
   /**
@@ -257,14 +264,18 @@ public List<InputPartition<InternalRow>> planInputPartitions() {
 
         READ_ORC_IGNORE_FILE_FIELD_IDS, READ_ORC_IGNORE_FILE_FIELD_IDS_DEFAULT);
 
-    List<InputPartition<InternalRow>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableSchemaString, expectedSchemaString, nameMappingString, io, encryptionManager, caseSensitive,
-          localityPreferred, InternalRowReaderFactory.INSTANCE, ignoreFileFieldIds));
-    }
+    List<CombinedScanTask> scanTasks = tasks();
+    InputPartition<InternalRow>[] readTasks = new InputPartition[scanTasks.size()];
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask<>(
+            scanTasks.get(index), tableSchemaString, expectedSchemaString, nameMappingString, io,
+            encryptionManager, caseSensitive,
+            localityPreferred, InternalRowReaderFactory.INSTANCE, ignoreFileFieldIds));
 
-    return readTasks;
+    return Arrays.asList(readTasks);
   }
 
   @Override
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
index 817651d030..e9bcc70352 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
@@ -46,6 +46,8 @@
 import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.RuntimeConfig;
 import org.apache.spark.sql.SparkSession;
@@ -145,11 +147,13 @@ public InputPartition[] planInputPartitions() {
 
     List<CombinedScanTask> scanTasks = tasks();
     InputPartition[] readTasks = new InputPartition[scanTasks.size()];
-    for (int i = 0; i < scanTasks.size(); i++) {
-      readTasks[i] = new ReadTask(
-          scanTasks.get(i), tableSchemaString, expectedSchemaString, nameMappingString, io, encryptionManager,
-          caseSensitive, localityPreferred, ignoreFileFieldIds);
-    }
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask(
+            scanTasks.get(index), tableSchemaString, expectedSchemaString, nameMappingString, io, encryptionManager,
+            caseSensitive, localityPreferred, ignoreFileFieldIds));
 
     return readTasks;
   }
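
For reviewers unfamiliar with the utility: below is a minimal standalone sketch of the `Tasks.range(...)` pattern that both files now use. It is not part of the patch; the class name and the dummy per-index work are hypothetical, only the `Tasks`/`ThreadPools` calls come from `org.apache.iceberg.util`. `Tasks.range(n)` iterates the indices `[0, n)`, and supplying an executor via `executeWith(...)` runs the body on worker threads; passing `null` (as the patch does when locality is off) runs it on the calling thread.

```java
import java.util.Arrays;
import org.apache.iceberg.util.Tasks;
import org.apache.iceberg.util.ThreadPools;

// Hypothetical demo class, not part of the patch.
public class ParallelPlanningSketch {
  public static void main(String[] args) {
    // Each index writes its own slot, so concurrent writes never collide,
    // which is the same reason the patch can fill readTasks[] in parallel.
    String[] results = new String[8];

    Tasks.range(results.length)
        .stopOnFailure()                           // abort remaining work on the first failure, as the patch does
        .executeWith(ThreadPools.getWorkerPool())  // shared Iceberg worker pool; null would mean single-threaded
        .run(index -> results[index] = "task-" + index);  // stand-in for building a ReadTask

    // run(...) returns only after all indices complete, so results is fully populated here.
    System.out.println(Arrays.toString(results));
  }
}
```

The array-plus-index shape matters: it preserves the original task order in the returned partitions while letting the expensive per-task setup (here, `ReadTask` construction with locality lookups) proceed concurrently.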