3 changes: 3 additions & 0 deletions core/src/main/java/org/apache/iceberg/TableProperties.java
@@ -188,4 +188,7 @@ private TableProperties() {

public static final String MERGE_CARDINALITY_CHECK_ENABLED = "write.merge.cardinality-check.enabled";
public static final boolean MERGE_CARDINALITY_CHECK_ENABLED_DEFAULT = true;

public static final String SPARK_BATCH_SCAN_POOL_SIZE = "read.spark.scan.pool-size";
public static final int SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT = 6;
}
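
The reader changes below resolve this setting from the scan options (options.getInt(...)), so it can also be supplied per read. A minimal usage sketch, assuming the usual org.apache.spark.sql imports, an existing SparkSession named spark, and a placeholder table identifier (all illustrative; only the option key and its default of 6 come from this PR):

```java
// Plan the scan's ReadTasks with a 12-thread pool instead of the default 6.
// `spark` and the table identifier "db.table" are placeholders.
Dataset<Row> df = spark.read()
    .format("iceberg")
    .option("read.spark.scan.pool-size", "12")
    .load("db.table");
df.show();
```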
70 changes: 55 additions & 15 deletions spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
@@ -21,9 +21,13 @@

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -95,6 +99,8 @@ class Reader implements DataSourceReader, SupportsScanColumnarBatch, SupportsPus
private Filter[] pushedFilters = NO_FILTERS;
private final boolean localityPreferred;
private final int batchSize;
private final int poolSize;
private final ForkJoinPool pool;

// lazy variables
private Schema schema = null;
@@ -157,6 +163,9 @@ class Reader implements DataSourceReader, SupportsScanColumnarBatch, SupportsPus
this.batchSize = options.get(SparkReadOptions.VECTORIZATION_BATCH_SIZE).map(Integer::parseInt).orElseGet(() ->
PropertyUtil.propertyAsInt(table.properties(),
TableProperties.PARQUET_BATCH_SIZE, TableProperties.PARQUET_BATCH_SIZE_DEFAULT));
this.poolSize = options.getInt(TableProperties.SPARK_BATCH_SCAN_POOL_SIZE,
TableProperties.SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT);
this.pool = new ForkJoinPool(this.poolSize);
}

private Schema lazySchema() {
@@ -193,6 +202,7 @@ public StructType readSchema() {
/**
* This is called in the Spark Driver when data is to be materialized into {@link ColumnarBatch}
*/
@SuppressWarnings("checkstyle:LocalVariableName")
@Override
public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
Preconditions.checkState(enableBatchRead(), "Batched reads not enabled");
@@ -205,35 +215,65 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
// broadcast the table metadata as input partitions will be sent to executors
Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));

List<InputPartition<ColumnarBatch>> readTasks = Lists.newArrayList();
for (CombinedScanTask task : tasks()) {
readTasks.add(new ReadTask<>(
task, tableBroadcast, expectedSchemaString, caseSensitive,
localityPreferred, new BatchReaderFactory(batchSize)));
int taskSize = tasks().size();
InputPartition<ColumnarBatch>[] readTasks = new InputPartition[taskSize];
Long startTime = System.currentTimeMillis();
try {
pool.submit(() -> IntStream.range(0, taskSize).parallel()
@kbendick (Contributor) commented on Jul 12, 2021:

Should the construction of all read tasks be done in a single submit to the thread pool? I don't see any way to slow the parallelism down here to avoid potentially overwhelming the name node.

For example, I would have expected each of the ranges in the int stream to be submitted to the pool individually, so that tasks queue up waiting for their turn. Here, the parallelism looks rather unbounded. Totally open to reading this wrong (it is Sunday for me after all!).
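
For reference, a minimal sketch of the per-task submission described above, reusing names from the surrounding diff (pool, poolSize, taskSize, readTasks, tasks(), tableBroadcast, expectedSchemaString, caseSensitive, localityPreferred, batchSize) and assuming java.util.ArrayList and java.util.concurrent.ForkJoinTask are imported; this is illustrative only, not code from the PR:

```java
// Queue each ReadTask construction as its own unit of work: at most poolSize
// constructions run concurrently and the rest wait in the pool's queue.
List<ForkJoinTask<?>> submissions = new ArrayList<>(taskSize);
for (int i = 0; i < taskSize; i++) {
  final int taskId = i;
  submissions.add(pool.submit(() -> {
    readTasks[taskId] = new ReadTask<>(
        tasks().get(taskId), tableBroadcast, expectedSchemaString, caseSensitive,
        localityPreferred, new BatchReaderFactory(batchSize));
  }));
}
submissions.forEach(ForkJoinTask::join);  // block until every ReadTask is built
```

That said, the form in the diff is also effectively bounded: a parallel stream whose terminal operation runs inside pool.submit(...) executes on that ForkJoinPool, so no more than poolSize ReadTask constructions run at once, although this relies on commonly used but undocumented parallel-stream behavior.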

.mapToObj(taskId -> {
LOG.trace("The size of scanTasks is {}, current taskId is {}, current thread id is {}",
taskSize, taskId, Thread.currentThread().getName());
readTasks[taskId] = new ReadTask<>(
tasks().get(taskId), tableBroadcast, expectedSchemaString, caseSensitive,
localityPreferred, new BatchReaderFactory(batchSize));
return true;
}).collect(Collectors.toList())).get();
} catch (Exception e) {
LOG.error("Fail to construct ReadTask with thread size = {}, and the size of scanTasks is {}",
poolSize, taskSize, e);
System.exit(-1);
}
LOG.info("Batching input partitions with {} tasks.", readTasks.size());

return readTasks;
Long endTime = System.currentTimeMillis();
LOG.info("Batching input partitions with {} tasks.", readTasks.length);
LOG.info("It took {} s to construct {} readTasks with localityPreferred = {}.", (endTime - startTime) / 1000,
taskSize, localityPreferred);
return Arrays.asList(readTasks.clone());
}

/**
* This is called in the Spark Driver when data is to be materialized into {@link InternalRow}
*/
@SuppressWarnings("checkstyle:LocalVariableName")
@Override
public List<InputPartition<InternalRow>> planInputPartitions() {
String expectedSchemaString = SchemaParser.toJson(lazySchema());

// broadcast the table metadata as input partitions will be sent to executors
Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));

List<InputPartition<InternalRow>> readTasks = Lists.newArrayList();
for (CombinedScanTask task : tasks()) {
readTasks.add(new ReadTask<>(
task, tableBroadcast, expectedSchemaString, caseSensitive,
localityPreferred, InternalRowReaderFactory.INSTANCE));
List<CombinedScanTask> scanTasks = tasks();
int taskSize = scanTasks.size();
InputPartition<InternalRow>[] readTasks = new InputPartition[taskSize];
Long startTime = System.currentTimeMillis();
try {
pool.submit(() -> IntStream.range(0, taskSize).parallel()
.mapToObj(taskId -> {
LOG.trace("The size of scanTasks is {}, current taskId is {}, current thread name is {}",
taskSize, taskId, Thread.currentThread().getName());
readTasks[taskId] = new ReadTask<>(
scanTasks.get(taskId), tableBroadcast, expectedSchemaString, caseSensitive,
localityPreferred, InternalRowReaderFactory.INSTANCE);
return true;
}).collect(Collectors.toList())).get();
} catch (Exception e) {
LOG.error("Fail to construct ReadTask with thread size = {}, and the size of scanTasks is {}",
poolSize, taskSize, e);
System.exit(-1);
@kbendick (Contributor) commented on Jul 12, 2021:

I would avoid explicitly calling System.exit and instead let the exception bubble up (possibly catching it and then rethrowing it with additional information added or as another exception type). This would make it easier for end users to track down their exceptions, particularly when working in a notebook where there's limited space to display stack traces already.

Is there a specific reason you chose to call System.exit that possibly I'm not aware of?

The PR author (Contributor) replied:

Thank you for reviewing this! I wanted to throw the exception at first, but doing so would also require changing code in org.apache.spark.sql.connector.read.Batch and other classes outside of Iceberg. Here is the error message:

planInputPartitions()' in 'org.apache.iceberg.spark.source.SparkBatchScan' clashes with 'planInputPartitions()' in 'org.apache.spark.sql.connector.read.Batch'; overridden method does not throw 'java.lang.Exception'
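
A hedged sketch of the rethrow alternative discussed in this thread: wrapping the failure in an unchecked exception keeps planInputPartitions() compatible with the interface it overrides (which declares no checked exceptions) while still letting the error bubble up, instead of calling System.exit in the driver. Names come from the surrounding diff; the choice of RuntimeException and the assumed java.util.concurrent.ExecutionException import are illustrative, not the PR's code:

```java
try {
  pool.submit(() -> IntStream.range(0, taskSize).parallel()
      .forEach(taskId -> {
        readTasks[taskId] = new ReadTask<>(
            scanTasks.get(taskId), tableBroadcast, expectedSchemaString, caseSensitive,
            localityPreferred, InternalRowReaderFactory.INSTANCE);
      })).get();
} catch (InterruptedException e) {
  Thread.currentThread().interrupt();
  throw new RuntimeException("Interrupted while constructing read tasks", e);
} catch (ExecutionException e) {
  // Rethrow unchecked so the caller sees the root cause rather than the driver exiting.
  throw new RuntimeException(
      "Failed to construct " + taskSize + " read tasks with pool size " + poolSize, e.getCause());
}
```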

}

return readTasks;
Long endTime = System.currentTimeMillis();
LOG.info("It took {} s to construct {} readTasks with localityPreferred = {}.", (endTime - startTime) / 1000,
taskSize, localityPreferred);
return Arrays.asList(readTasks.clone());
}

@Override
@@ -24,7 +24,9 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.FileScanTask;
@@ -33,6 +35,7 @@
import org.apache.iceberg.SerializableTable;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.hadoop.HadoopInputFile;
import org.apache.iceberg.hadoop.Util;
@@ -70,6 +73,8 @@ abstract class SparkBatchScan implements Scan, Batch, SupportsReportStatistics {
private final List<Expression> filterExpressions;
private final int batchSize;
private final CaseInsensitiveStringMap options;
private final int poolSize;
private final ForkJoinPool pool;

// lazy variables
private StructType readSchema = null;
@@ -84,6 +89,9 @@ abstract class SparkBatchScan implements Scan, Batch, SupportsReportStatistics {
this.localityPreferred = Spark3Util.isLocalityEnabled(table.io(), table.location(), options);
this.batchSize = Spark3Util.batchSize(table.properties(), options);
this.options = options;
this.poolSize = options.getInt(TableProperties.SPARK_BATCH_SCAN_POOL_SIZE,
TableProperties.SPARK_BATCH_SCAN_POOL_SIZE_DEFAULT);
this.pool = new ForkJoinPool(this.poolSize);
}

protected Table table() {
@@ -123,6 +131,7 @@ public StructType readSchema() {
return readSchema;
}

@SuppressWarnings("checkstyle:LocalVariableName")
@Override
public InputPartition[] planInputPartitions() {
String expectedSchemaString = SchemaParser.toJson(expectedSchema);
@@ -131,13 +140,28 @@ public InputPartition[] planInputPartitions() {
Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));

List<CombinedScanTask> scanTasks = tasks();
InputPartition[] readTasks = new InputPartition[scanTasks.size()];
for (int i = 0; i < scanTasks.size(); i++) {
readTasks[i] = new ReadTask(
scanTasks.get(i), tableBroadcast, expectedSchemaString,
caseSensitive, localityPreferred);
int taskSize = scanTasks.size();
InputPartition[] readTasks = new InputPartition[taskSize];
long startTime = System.currentTimeMillis();

try {
pool.submit(() -> IntStream.range(0, taskSize).parallel()
.mapToObj(taskId -> {
LOG.trace("The size of scanTasks is {}, current taskId is {}, current thread id is {}",
taskSize, taskId, Thread.currentThread().getName());
readTasks[taskId] = new ReadTask(scanTasks.get(taskId), tableBroadcast,
expectedSchemaString, caseSensitive, localityPreferred);
return true;
}).collect(Collectors.toList())).get();
} catch (Exception e) {
LOG.error("Fail to construct ReadTask with thread size = {}, and the size of scanTasks is {}",
poolSize, taskSize, e);
System.exit(-1);
}

long endTime = System.currentTimeMillis();
LOG.info("It took {} s to construct {} readTasks with localityPreferred = {}.", (startTime - endTime) / 1000,
taskSize, localityPreferred);
return readTasks;
}
