diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieMergeOnReadRDD.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieMergeOnReadRDD.scala
index f26cd881e1ea0..ffe2c92829698 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieMergeOnReadRDD.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieMergeOnReadRDD.scala
@@ -35,6 +35,8 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.{Partition, SerializableWritable, SparkContext, TaskContext}
 
+import java.io.Closeable
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Try
@@ -58,7 +60,7 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     val mergeOnReadPartition = split.asInstanceOf[HoodieMergeOnReadPartition]
-    mergeOnReadPartition.split match {
+    val iter = mergeOnReadPartition.split match {
       case dataFileOnlySplit if dataFileOnlySplit.logPaths.isEmpty =>
         read(dataFileOnlySplit.dataFile.get, requiredSchemaFileReader)
       case logFileOnlySplit if logFileOnlySplit.dataFile.isEmpty =>
@@ -84,6 +86,12 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
           s"spark partition Index: ${mergeOnReadPartition.index}" +
           s"merge type: ${mergeOnReadPartition.split.mergeType}")
     }
+    if (iter.isInstanceOf[Closeable]) {
+      // Register a callback to close the iterator's logScanner on task completion,
+      // so its resources are released once the task finishes.
+      Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.asInstanceOf[Closeable].close()))
+    }
+    iter
   }
 
   override protected def getPartitions: Array[Partition] = {
@@ -112,7 +120,7 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
 
   private def logFileIterator(split: HoodieMergeOnReadFileSplit,
                               config: Configuration): Iterator[InternalRow] =
-    new Iterator[InternalRow] {
+    new Iterator[InternalRow] with Closeable {
       private val tableAvroSchema = new Schema.Parser().parse(tableState.tableAvroSchema)
       private val requiredAvroSchema = new Schema.Parser().parse(tableState.requiredAvroSchema)
       private val requiredFieldPosition =
@@ -121,7 +129,8 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
       private val recordBuilder = new GenericRecordBuilder(requiredAvroSchema)
       private val deserializer = HoodieAvroDeserializer(requiredAvroSchema, tableState.requiredStructSchema)
       private val unsafeProjection = UnsafeProjection.create(tableState.requiredStructSchema)
-      private val logRecords = HoodieMergeOnReadRDD.scanLog(split, tableAvroSchema, config).getRecords
+      private var logScanner = HoodieMergeOnReadRDD.scanLog(split, tableAvroSchema, config)
+      private val logRecords = logScanner.getRecords
       private val logRecordsKeyIterator = logRecords.keySet().iterator().asScala
 
       private var recordToLoad: InternalRow = _
@@ -146,12 +155,22 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
       override def next(): InternalRow = {
         recordToLoad
       }
+
+      override def close(): Unit = {
+        if (logScanner != null) {
+          try {
+            logScanner.close()
+          } finally {
+            logScanner = null
+          }
+        }
+      }
     }
 
   private def skipMergeFileIterator(split: HoodieMergeOnReadFileSplit,
                                     baseFileIterator: Iterator[InternalRow],
                                     config: Configuration): Iterator[InternalRow] =
-    new Iterator[InternalRow] {
+    new Iterator[InternalRow] with Closeable {
       private val tableAvroSchema = new Schema.Parser().parse(tableState.tableAvroSchema)
       private val requiredAvroSchema = new Schema.Parser().parse(tableState.requiredAvroSchema)
       private val requiredFieldPosition =
@@ -160,7 +179,8 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
       private val recordBuilder = new GenericRecordBuilder(requiredAvroSchema)
       private val deserializer = HoodieAvroDeserializer(requiredAvroSchema, tableState.requiredStructSchema)
       private val unsafeProjection = UnsafeProjection.create(tableState.requiredStructSchema)
-      private val logRecords = HoodieMergeOnReadRDD.scanLog(split, tableAvroSchema, config).getRecords
+      private var logScanner = HoodieMergeOnReadRDD.scanLog(split, tableAvroSchema, config)
+      private val logRecords = logScanner.getRecords
       private val logRecordsKeyIterator = logRecords.keySet().iterator().asScala
 
       private var recordToLoad: InternalRow = _
@@ -192,12 +212,22 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
       override def next(): InternalRow = {
         recordToLoad
       }
+
+      override def close(): Unit = {
+        if (logScanner != null) {
+          try {
+            logScanner.close()
+          } finally {
+            logScanner = null
+          }
+        }
+      }
     }
 
   private def payloadCombineFileIterator(split: HoodieMergeOnReadFileSplit,
                                          baseFileIterator: Iterator[InternalRow],
                                          config: Configuration): Iterator[InternalRow] =
-    new Iterator[InternalRow] {
+    new Iterator[InternalRow] with Closeable {
      private val tableAvroSchema = new Schema.Parser().parse(tableState.tableAvroSchema)
       private val requiredAvroSchema = new Schema.Parser().parse(tableState.requiredAvroSchema)
       private val requiredFieldPosition =
@@ -207,7 +237,8 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
       private val requiredDeserializer = HoodieAvroDeserializer(requiredAvroSchema, tableState.requiredStructSchema)
       private val recordBuilder = new GenericRecordBuilder(requiredAvroSchema)
       private val unsafeProjection = UnsafeProjection.create(tableState.requiredStructSchema)
-      private val logRecords = HoodieMergeOnReadRDD.scanLog(split, tableAvroSchema, config).getRecords
+      private var logScanner = HoodieMergeOnReadRDD.scanLog(split, tableAvroSchema, config)
+      private val logRecords = logScanner.getRecords
       private val logRecordsKeyIterator = logRecords.keySet().iterator().asScala
       private val keyToSkip = mutable.Set.empty[String]
       private val recordKeyPosition = if (recordKeyFieldOpt.isEmpty) HOODIE_RECORD_KEY_COL_POS else tableState.tableStructSchema.fieldIndex(recordKeyFieldOpt.get)
@@ -276,6 +307,16 @@ class HoodieMergeOnReadRDD(@transient sc: SparkContext,
 
       override def next(): InternalRow = recordToLoad
 
+      override def close(): Unit = {
+        if (logScanner != null) {
+          try {
+            logScanner.close()
+          } finally {
+            logScanner = null
+          }
+        }
+      }
+
       private def createRowWithRequiredSchema(row: InternalRow): InternalRow = {
         val rowToReturn = new SpecificInternalRow(tableState.requiredStructSchema)
         val posIterator = requiredFieldPosition.iterator