diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cce48204ba..f9a2546693 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -35,8 +35,9 @@ import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
-import org.apache.spark.sql.execution.datasources.FileScanRDD
+import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
 import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
@@ -2497,22 +2498,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     val dataFilters = scan.dataFilters.map(exprToProto(_, scan.output))
     nativeScanBuilder.addAllDataFilters(dataFilters.map(_.get).asJava)
 
-    // Eventually we'll want to modify CometNativeScan to generate the file partitions
-    // for us without instantiating the RDD.
-    val file_partitions = scan.inputRDD.asInstanceOf[FileScanRDD].filePartitions;
-    file_partitions.foreach(partition => {
-      val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
-      partition.files.foreach(file => {
-        val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
-        fileBuilder
-          .setFilePath(file.pathUri.toString)
-          .setStart(file.start)
-          .setLength(file.length)
-          .setFileSize(file.fileSize)
-        partitionBuilder.addPartitionedFile(fileBuilder.build())
-      })
-      nativeScanBuilder.addFilePartitions(partitionBuilder.build())
-    })
+    // TODO: modify CometNativeScan to generate the file partitions without instantiating the RDD.
+    scan.inputRDD match {
+      case rdd: DataSourceRDD =>
+        val partitions = rdd.partitions
+        partitions.foreach(p => {
+          val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
+          inputPartitions.foreach(partition => {
+            partition2Proto(partition.asInstanceOf[FilePartition], nativeScanBuilder)
+          })
+        })
+      case rdd: FileScanRDD =>
+        rdd.filePartitions.foreach(partition => {
+          partition2Proto(partition, nativeScanBuilder)
+        })
+      case _ =>
+    }
 
     val requiredSchemaParquet =
       new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
@@ -3185,4 +3186,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
     true
   }
+
+  private def partition2Proto(
+      partition: FilePartition,
+      nativeScanBuilder: OperatorOuterClass.NativeScan.Builder): Unit = {
+    val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
+    partition.files.foreach(file => {
+      val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
+      fileBuilder
+        .setFilePath(file.pathUri.toString)
+        .setStart(file.start)
+        .setLength(file.length)
+        .setFileSize(file.fileSize)
+      partitionBuilder.addPartitionedFile(fileBuilder.build())
+    })
+    nativeScanBuilder.addFilePartitions(partitionBuilder.build())
+  }
 }
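
The shape of the change: instead of assuming scan.inputRDD is always a FileScanRDD (DataSource v1), the serializer now matches on the RDD type and funnels both v1 and v2 inputs through the shared partition2Proto helper, so only partition discovery differs per source version while the protobuf-building stays in one place. Below is a minimal self-contained sketch of that dispatch; PartitionedFile, FilePartition, FileScanLike, and DataSourceLike here are simplified stand-in types for illustration, not the real Spark/Comet classes.

// Illustrative sketch only: these stand-ins mirror the structure of Spark's
// FileScanRDD (DSv1) and DataSourceRDD (DSv2), not the real classes.
object PartitionDispatchSketch {
  final case class PartitionedFile(path: String, start: Long, length: Long, fileSize: Long)
  final case class FilePartition(files: Seq[PartitionedFile])

  sealed trait InputRdd
  // DSv1: the RDD exposes fully-formed FilePartitions directly.
  final case class FileScanLike(filePartitions: Seq[FilePartition]) extends InputRdd
  // DSv2: each RDD partition wraps untyped input partitions that must be downcast.
  final case class DataSourceLike(partitions: Seq[Seq[Any]]) extends InputRdd

  // Normalize both flavors to FilePartition before serializing, mirroring the
  // match on scan.inputRDD in the patch above.
  def collectFilePartitions(rdd: InputRdd): Seq[FilePartition] = rdd match {
    case DataSourceLike(partitions) =>
      // DSv2 only guarantees generic input partitions, hence the cast; a
      // non-file-based source would fail here, just as it would with the
      // asInstanceOf[FilePartition] in the patch.
      partitions.flatten.map(_.asInstanceOf[FilePartition])
    case FileScanLike(filePartitions) =>
      filePartitions
  }

  def main(args: Array[String]): Unit = {
    val v1 = FileScanLike(Seq(FilePartition(Seq(PartitionedFile("/tmp/part-0.parquet", 0L, 128L, 128L)))))
    val v2 = DataSourceLike(Seq(Seq(FilePartition(Seq(PartitionedFile("/tmp/part-1.parquet", 0L, 64L, 64L))))))
    println(collectFilePartitions(v1).flatMap(_.files).map(_.path)) // List(/tmp/part-0.parquet)
    println(collectFilePartitions(v2).flatMap(_.files).map(_.path)) // List(/tmp/part-1.parquet)
  }
}

Note that, as in the patch, the DSv2 branch relies on a downcast (DataSourceRDDPartition only guarantees InputPartition), and the trailing case _ => in the patch means an unrecognized input RDD silently serializes no file partitions.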