diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 8cd301612b..babc686960 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -969,6 +969,19 @@ impl PhysicalPlanner {
                 .unwrap(),
         );
 
+        let partition_schema_arrow = scan
+            .partition_schema
+            .iter()
+            .map(to_arrow_datatype)
+            .collect_vec();
+        let partition_fields: Vec<_> = partition_schema_arrow
+            .iter()
+            .enumerate()
+            .map(|(idx, data_type)| {
+                Field::new(format!("part_{}", idx), data_type.clone(), true)
+            })
+            .collect();
+
         // Convert the Spark expressions to Physical expressions
         let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
             .data_filters
@@ -997,11 +1010,12 @@ impl PhysicalPlanner {
         // Generate file groups
         let mut file_groups: Vec<Vec<PartitionedFile>> =
             Vec::with_capacity(partition_count);
-        scan.file_partitions.iter().for_each(|partition| {
+        scan.file_partitions.iter().try_for_each(|partition| {
             let mut files = Vec::with_capacity(partition.partitioned_file.len());
-            partition.partitioned_file.iter().for_each(|file| {
+            partition.partitioned_file.iter().try_for_each(|file| {
                 assert!(file.start + file.length <= file.file_size);
-                files.push(PartitionedFile::new_with_range(
+
+                let mut partitioned_file = PartitionedFile::new_with_range(
                     Url::parse(file.file_path.as_ref())
                         .unwrap()
                         .path()
@@ -1009,10 +1023,41 @@ impl PhysicalPlanner {
                     file.file_size as u64,
                     file.start,
                     file.start + file.length,
-                ));
-            });
+                );
+
+                // Process partition values
+                // Create an empty input schema for partition values because they are all literals.
+                let empty_schema = Arc::new(Schema::empty());
+                let partition_values: Result<Vec<_>, _> = file
+                    .partition_values
+                    .iter()
+                    .map(|partition_value| {
+                        let literal = self.create_expr(
+                            partition_value,
+                            Arc::<Schema>::clone(&empty_schema),
+                        )?;
+                        literal
+                            .as_any()
+                            .downcast_ref::<Literal>()
+                            .ok_or_else(|| {
+                                ExecutionError::GeneralError(
+                                    "Expected literal of partition value".to_string(),
+                                )
+                            })
+                            .map(|literal| literal.value().clone())
+                    })
+                    .collect();
+                let partition_values = partition_values?;
+
+                partitioned_file.partition_values = partition_values;
+
+                files.push(partitioned_file);
+                Ok::<(), ExecutionError>(())
+            })?;
+
             file_groups.push(files);
-        });
+            Ok::<(), ExecutionError>(())
+        })?;
 
         // TODO: I think we can remove partition_count in the future, but leave for testing.
         assert_eq!(file_groups.len(), partition_count);
@@ -1020,7 +1065,8 @@ impl PhysicalPlanner {
         let object_store_url = ObjectStoreUrl::local_filesystem();
         let mut file_scan_config =
             FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
-                .with_file_groups(file_groups);
+                .with_file_groups(file_groups)
+                .with_table_partition_cols(partition_fields);
 
         // Check for projection, if so generate the vector and add to FileScanConfig.
         let mut projection_vector: Vec<usize> =
@@ -1030,7 +1076,17 @@ impl PhysicalPlanner {
                 projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
             });
 
-        assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());
+        partition_schema_arrow
+            .iter()
+            .enumerate()
+            .for_each(|(idx, _)| {
+                projection_vector.push(idx + data_schema_arrow.fields.len());
+            });
+
+        assert_eq!(
+            projection_vector.len(),
+            required_schema_arrow.fields.len() + partition_schema_arrow.len()
+        );
         file_scan_config = file_scan_config.with_projection(Some(projection_vector));
 
         let mut table_parquet_options = TableParquetOptions::new();
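For context, a minimal, hypothetical sketch (not part of this patch) of the DataFusion mechanism the planner change relies on: each `PartitionedFile` carries its partition values as `ScalarValue`s, and `FileScanConfig::with_table_partition_cols` appends those columns after the file schema, which is why the projection indices above are offset by `data_schema_arrow.fields.len()`. The `FileScanConfig::new` signature and import paths assume a recent DataFusion release; the file path, schema, and values below are made up.

```rust
// Standalone sketch, not Comet code. Assumes a recent DataFusion release where
// FileScanConfig::new(object_store_url, file_schema) exists.
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ScalarValue;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::FileScanConfig;
use datafusion::execution::object_store::ObjectStoreUrl;

fn main() {
    // Columns physically stored in the Parquet file.
    let file_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, true),
        Field::new("name", DataType::Utf8, true),
    ]));

    // One partition column; the patch names these columns "part_0", "part_1", ...
    let partition_fields = vec![Field::new("part_0", DataType::Utf8, true)];

    // Each file carries the literal partition values for its directory.
    let mut file = PartitionedFile::new("/data/part_0=2024-01-01/file.parquet", 1024);
    file.partition_values = vec![ScalarValue::Utf8(Some("2024-01-01".to_string()))];

    // Partition columns are appended after the file columns, so partition column i
    // is selected with projection index file_schema.fields().len() + i (here: 2).
    let config = FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema)
        .with_file_groups(vec![vec![file]])
        .with_table_partition_cols(partition_fields)
        .with_projection(Some(vec![0, 1, 2]));

    println!(
        "partition values of first file: {:?}",
        config.file_groups[0][0].partition_values
    );
}
```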
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index f50389dbe5..5e8a80f999 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -52,6 +52,7 @@ message SparkPartitionedFile {
   int64 start = 2;
   int64 length = 3;
   int64 file_size = 4;
+  repeated spark.spark_expression.Expr partition_values = 5;
 }
 
 // This name and the one above are not great, but they correspond to the (unfortunate) Spark names.
@@ -76,8 +77,9 @@ message NativeScan {
   string source = 2;
   string required_schema = 3;
   string data_schema = 4;
-  repeated spark.spark_expression.Expr data_filters = 5;
-  repeated SparkFilePartition file_partitions = 6;
+  repeated spark.spark_expression.DataType partition_schema = 5;
+  repeated spark.spark_expression.Expr data_filters = 6;
+  repeated SparkFilePartition file_partitions = 7;
 }
 
 message Projection {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b8a780e608..7b978c9860 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, Normalize
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometNativeScanExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
@@ -2507,12 +2507,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             partitions.foreach(p => {
               val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
               inputPartitions.foreach(partition => {
-                partition2Proto(partition.asInstanceOf[FilePartition], nativeScanBuilder)
+                partition2Proto(
+                  partition.asInstanceOf[FilePartition],
+                  nativeScanBuilder,
+                  scan.relation.partitionSchema)
               })
             })
           case rdd: FileScanRDD =>
             rdd.filePartitions.foreach(partition => {
-              partition2Proto(partition, nativeScanBuilder)
+              partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
             })
           case _ =>
         }
@@ -2521,9 +2524,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
         val dataSchemaParquet =
          new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
+        val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
+          serializeDataType(field.dataType)
+        }
+        // In `CometScanRule`, we ensure partitionSchema is supported.
+        assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
 
         nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
         nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+        nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
 
         Some(result.setNativeScan(nativeScanBuilder).build())
 
@@ -3191,10 +3200,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
   private def partition2Proto(
       partition: FilePartition,
-      nativeScanBuilder: OperatorOuterClass.NativeScan.Builder): Unit = {
+      nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
+      partitionSchema: StructType): Unit = {
     val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
     partition.files.foreach(file => {
+      // Process the partition values
+      val partitionValues = file.partitionValues
+      assert(partitionValues.numFields == partitionSchema.length)
+      val partitionVals =
+        partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) =>
+          val attr = partitionSchema(i)
+          val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty)
+          // In `CometScanRule`, we have already checked that all partition values are
+          // supported. So, we can safely use `get` here.
+          assert(
+            valueProto.isDefined,
+            s"Unsupported partition value: $value, type: ${attr.dataType}")
+          valueProto.get
+        }
+
       val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
+      partitionVals.foreach(fileBuilder.addPartitionValues)
       fileBuilder
         .setFilePath(file.pathUri.toString)
         .setStart(file.start)
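On the native side, each serialized partition value arrives as a literal expression proto; the planner change above turns it back into a `ScalarValue` by downcasting to DataFusion's `Literal`. A small self-contained sketch of that extraction pattern (illustrative only, assuming DataFusion's `physical_expr` API; the helper name `literal_value` is made up):

```rust
// Illustrative sketch, not Comet code: mirrors how a partition-value expression
// that is expected to be a literal is converted into a ScalarValue.
use std::sync::Arc;

use datafusion::common::ScalarValue;
use datafusion::physical_expr::expressions::Literal;
use datafusion::physical_expr::PhysicalExpr;

// Hypothetical helper: fails with a message similar to the planner's
// "Expected literal of partition value" error when the downcast fails.
fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue, String> {
    expr.as_any()
        .downcast_ref::<Literal>()
        .map(|lit| lit.value().clone())
        .ok_or_else(|| "Expected literal of partition value".to_string())
}

fn main() {
    // A partition value such as `part_0=2024-01-01` deserializes to a literal expression.
    let expr: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Utf8(Some("2024-01-01".to_string()))));
    println!("{:?}", literal_value(&expr));
}
```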