diff --git a/core/src/main/java/org/apache/iceberg/TableProperties.java b/core/src/main/java/org/apache/iceberg/TableProperties.java
index 56d921b918a6..38c80a4d6dc1 100644
--- a/core/src/main/java/org/apache/iceberg/TableProperties.java
+++ b/core/src/main/java/org/apache/iceberg/TableProperties.java
@@ -188,4 +188,7 @@ private TableProperties() {
   public static final String MERGE_CARDINALITY_CHECK_ENABLED =
       "write.merge.cardinality-check.enabled";
   public static final boolean MERGE_CARDINALITY_CHECK_ENABLED_DEFAULT = true;
+
+  public static final String SPARK_BATCH_SCAN_POOL_SIZE = "read.spark.scan.pool-size";
+  public static final int SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT = 6;
 }
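The new key is defined alongside the other table properties, but the two readers below resolve it only from the Spark read options (`options.getInt`). If a per-table default is also wanted, the resolution could mirror how `PARQUET_BATCH_SIZE` is handled in `Reader`; a minimal sketch of that option-then-property fallback (the `ScanPoolSize` helper is illustrative and not part of this patch):

```java
import java.util.Map;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.util.PropertyUtil;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

// Hypothetical helper, not part of this patch: resolve the pool size from the
// Spark read option first, then the table property, then the built-in default (6).
class ScanPoolSize {
  private ScanPoolSize() {
  }

  static int resolve(CaseInsensitiveStringMap options, Map<String, String> tableProperties) {
    int tableDefault = PropertyUtil.propertyAsInt(
        tableProperties,
        TableProperties.SPARK_BATCH_SCAN_POOL_SIZE,
        TableProperties.SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT);
    // The read option, if present, overrides the table-level setting.
    return options.getInt(TableProperties.SPARK_BATCH_SCAN_POOL_SIZE, tableDefault);
  }
}
```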
diff --git a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
index 45a13f2762c4..c5c26365c260 100644
--- a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
+++ b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
@@ -21,9 +21,13 @@
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.concurrent.ForkJoinPool;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -95,6 +99,8 @@ class Reader implements DataSourceReader, SupportsScanColumnarBatch, SupportsPus
   private Filter[] pushedFilters = NO_FILTERS;
   private final boolean localityPreferred;
   private final int batchSize;
+  private final int poolSize;
+  private final ForkJoinPool pool;
 
   // lazy variables
   private Schema schema = null;
@@ -157,6 +163,9 @@ class Reader implements DataSourceReader, SupportsScanColumnarBatch, SupportsPus
     this.batchSize = options.get(SparkReadOptions.VECTORIZATION_BATCH_SIZE).map(Integer::parseInt).orElseGet(() ->
         PropertyUtil.propertyAsInt(table.properties(),
             TableProperties.PARQUET_BATCH_SIZE, TableProperties.PARQUET_BATCH_SIZE_DEFAULT));
+    this.poolSize = options.getInt(TableProperties.SPARK_BATCH_SCAN_POOL_SIZE,
+        TableProperties.SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT);
+    this.pool = new ForkJoinPool(this.poolSize);
   }
 
   private Schema lazySchema() {
@@ -193,6 +202,7 @@ public StructType readSchema() {
   /**
    * This is called in the Spark Driver when data is to be materialized into {@link ColumnarBatch}
    */
+  @SuppressWarnings("checkstyle:LocalVariableName")
   @Override
   public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
     Preconditions.checkState(enableBatchRead(), "Batched reads not enabled");
@@ -205,35 +215,65 @@
     // broadcast the table metadata as input partitions will be sent to executors
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
-    List<InputPartition<ColumnarBatch>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableBroadcast, expectedSchemaString, caseSensitive,
-          localityPreferred, new BatchReaderFactory(batchSize)));
+    int taskSize = tasks().size();
+    InputPartition<ColumnarBatch>[] readTasks = new InputPartition[taskSize];
+    long startTime = System.currentTimeMillis();
+    try {
+      pool.submit(() -> IntStream.range(0, taskSize).parallel()
+          .mapToObj(taskId -> {
+            LOG.trace("The size of scanTasks is {}, current taskId is {}, current thread name is {}",
+                taskSize, taskId, Thread.currentThread().getName());
+            readTasks[taskId] = new ReadTask<>(
+                tasks().get(taskId), tableBroadcast, expectedSchemaString, caseSensitive,
+                localityPreferred, new BatchReaderFactory(batchSize));
+            return true;
+          }).collect(Collectors.toList())).get();
+    } catch (Exception e) {
+      LOG.error("Fail to construct ReadTask with thread size = {}, and the size of scanTasks is {}",
+          poolSize, taskSize, e);
+      System.exit(-1);
     }
-    LOG.info("Batching input partitions with {} tasks.", readTasks.size());
-
-    return readTasks;
+    long endTime = System.currentTimeMillis();
+    LOG.info("Batching input partitions with {} tasks.", readTasks.length);
+    LOG.info("It took {} s to construct {} readTasks with localityPreferred = {}.", (endTime - startTime) / 1000,
+        taskSize, localityPreferred);
+    return Arrays.asList(readTasks.clone());
   }
 
   /**
    * This is called in the Spark Driver when data is to be materialized into {@link InternalRow}
    */
+  @SuppressWarnings("checkstyle:LocalVariableName")
   @Override
   public List<InputPartition<InternalRow>> planInputPartitions() {
     String expectedSchemaString = SchemaParser.toJson(lazySchema());
 
     // broadcast the table metadata as input partitions will be sent to executors
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
-
-    List<InputPartition<InternalRow>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableBroadcast, expectedSchemaString, caseSensitive,
-          localityPreferred, InternalRowReaderFactory.INSTANCE));
+    List<CombinedScanTask> scanTasks = tasks();
+    int taskSize = scanTasks.size();
+    InputPartition<InternalRow>[] readTasks = new InputPartition[taskSize];
+    long startTime = System.currentTimeMillis();
+    try {
+      pool.submit(() -> IntStream.range(0, taskSize).parallel()
+          .mapToObj(taskId -> {
+            LOG.trace("The size of scanTasks is {}, current taskId is {}, current thread name is {}",
+                taskSize, taskId, Thread.currentThread().getName());
+            readTasks[taskId] = new ReadTask<>(
+                scanTasks.get(taskId), tableBroadcast, expectedSchemaString, caseSensitive,
+                localityPreferred, InternalRowReaderFactory.INSTANCE);
+            return true;
+          }).collect(Collectors.toList())).get();
+    } catch (Exception e) {
+      LOG.error("Fail to construct ReadTask with thread size = {}, and the size of scanTasks is {}",
+          poolSize, taskSize, e);
+      System.exit(-1);
     }
-    return readTasks;
+    long endTime = System.currentTimeMillis();
+    LOG.info("It took {} s to construct {} readTasks with localityPreferred = {}.", (endTime - startTime) / 1000,
+        taskSize, localityPreferred);
+    return Arrays.asList(readTasks.clone());
   }
 
   @Override
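Both methods rely on the same pattern: a parallel `IntStream` is started from inside `pool.submit(...)`, which (by widely relied-on but undocumented `ForkJoinPool` behavior) keeps the per-index work on the dedicated pool instead of `ForkJoinPool.commonPool()`, and `get()` blocks until the `readTasks` array is fully populated. A standalone sketch of just that pattern, outside Iceberg (all names below are illustrative):

```java
import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ParallelFillExample {
  public static void main(String[] args) throws ExecutionException, InterruptedException {
    int taskSize = 16;
    String[] results = new String[taskSize];
    ForkJoinPool pool = new ForkJoinPool(4); // plays the role of read.spark.scan.pool-size

    try {
      // The terminal collect() runs the per-index work on the submitting pool's
      // workers; get() blocks until every slot of the array has been filled.
      pool.submit(() -> IntStream.range(0, taskSize).parallel()
          .mapToObj(i -> {
            results[i] = "partition-" + i + " built on " + Thread.currentThread().getName();
            return true;
          }).collect(Collectors.toList())).get();
    } finally {
      pool.shutdown(); // the sketch cleans up; the patch keeps its pool for the reader's lifetime
    }

    Arrays.asList(results).forEach(System.out::println);
  }
}
```

Unlike this sketch, the patch holds the `ForkJoinPool` as a field and never shuts it down, so the pool lives as long as the `Reader`/`SparkBatchScan` instance that created it.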
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
index 8cf755b82eb9..fd22607543da 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
@@ -24,7 +24,9 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ForkJoinPool;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.apache.iceberg.CombinedScanTask;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.FileScanTask;
@@ -33,6 +35,7 @@
 import org.apache.iceberg.SerializableTable;
 import org.apache.iceberg.SnapshotSummary;
 import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.hadoop.HadoopInputFile;
 import org.apache.iceberg.hadoop.Util;
@@ -70,6 +73,8 @@ abstract class SparkBatchScan implements Scan, Batch, SupportsReportStatistics {
   private final List<Expression> filterExpressions;
   private final int batchSize;
   private final CaseInsensitiveStringMap options;
+  private final int poolSize;
+  private final ForkJoinPool pool;
 
   // lazy variables
   private StructType readSchema = null;
@@ -84,6 +89,9 @@ abstract class SparkBatchScan implements Scan, Batch, SupportsReportStatistics {
     this.localityPreferred = Spark3Util.isLocalityEnabled(table.io(), table.location(), options);
     this.batchSize = Spark3Util.batchSize(table.properties(), options);
     this.options = options;
+    this.poolSize = options.getInt(TableProperties.SPARK_BATCH_SCAN_POOL_SIZE,
+        TableProperties.SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT);
+    this.pool = new ForkJoinPool(this.poolSize);
   }
 
   protected Table table() {
@@ -123,6 +131,7 @@ public StructType readSchema() {
     return readSchema;
   }
 
+  @SuppressWarnings("checkstyle:LocalVariableName")
   @Override
   public InputPartition[] planInputPartitions() {
     String expectedSchemaString = SchemaParser.toJson(expectedSchema);
@@ -131,13 +140,28 @@ public InputPartition[] planInputPartitions() {
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
 
     List<CombinedScanTask> scanTasks = tasks();
-    InputPartition[] readTasks = new InputPartition[scanTasks.size()];
-    for (int i = 0; i < scanTasks.size(); i++) {
-      readTasks[i] = new ReadTask(
-          scanTasks.get(i), tableBroadcast, expectedSchemaString,
-          caseSensitive, localityPreferred);
+    int taskSize = scanTasks.size();
+    InputPartition[] readTasks = new InputPartition[taskSize];
+    long startTime = System.currentTimeMillis();
+
+    try {
+      pool.submit(() -> IntStream.range(0, taskSize).parallel()
+          .mapToObj(taskId -> {
+            LOG.trace("The size of scanTasks is {}, current taskId is {}, current thread name is {}",
+                taskSize, taskId, Thread.currentThread().getName());
+            readTasks[taskId] = new ReadTask(scanTasks.get(taskId), tableBroadcast,
+                expectedSchemaString, caseSensitive, localityPreferred);
+            return true;
+          }).collect(Collectors.toList())).get();
+    } catch (Exception e) {
+      LOG.error("Fail to construct ReadTask with thread size = {}, and the size of scanTasks is {}",
+          poolSize, taskSize, e);
+      System.exit(-1);
     }
+    long endTime = System.currentTimeMillis();
+    LOG.info("It took {} s to construct {} readTasks with localityPreferred = {}.", (endTime - startTime) / 1000,
+        taskSize, localityPreferred);
 
     return readTasks;
   }
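To exercise the knob end to end, the pool size can be supplied as a Spark read option; a minimal sketch, assuming a running `SparkSession` and an Iceberg table reachable as `db.tbl` (both placeholders):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ScanPoolSizeReadExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("scan-pool-size-example")
        .getOrCreate();

    // "read.spark.scan.pool-size" is the key added to TableProperties in this patch;
    // it is picked up via options.getInt(...) in Reader (Spark 2) and SparkBatchScan (Spark 3).
    Dataset<Row> df = spark.read()
        .format("iceberg")
        .option("read.spark.scan.pool-size", "12") // override the default of 6
        .load("db.tbl");

    System.out.println(df.count());
  }
}
```

The construction timing added in `planInputPartitions` then shows up in the driver log at INFO level.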