Merged
native/core/src/execution/datafusion/planner.rs (96 changes: 40 additions & 56 deletions)
@@ -121,7 +121,6 @@ use datafusion_physical_expr::LexOrdering;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
-use parquet::schema::parser::parse_message_type;
 use std::cmp::max;
 use std::{collections::HashMap, sync::Arc};
 use url::Url;
@@ -950,50 +949,28 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::NativeScan(scan) => {
-                let data_schema = parse_message_type(&scan.data_schema).unwrap();
-                let required_schema = parse_message_type(&scan.required_schema).unwrap();
-
-                let data_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
-                let data_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None)
-                        .unwrap(),
-                );
-
-                let required_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema));
-                let required_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(
-                        &required_schema_descriptor,
-                        None,
-                    )
-                    .unwrap(),
-                );
-
-                let partition_schema_arrow = scan
-                    .partition_schema
+                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+                let required_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+                let partition_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+                let projection_vector: Vec<usize> = scan
+                    .projection_vector
                     .iter()
-                    .map(to_arrow_datatype)
-                    .collect_vec();
-                let partition_fields: Vec<_> = partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .map(|(idx, data_type)| {
-                        Field::new(format!("part_{}", idx), data_type.clone(), true)
-                    })
+                    .map(|offset| *offset as usize)
                     .collect();
 
                 // Convert the Spark expressions to Physical expressions
                 let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                     .data_filters
                     .iter()
-                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow)))
+                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                     .collect();
 
                 // Create a conjunctive form of the vector because ParquetExecBuilder takes
                 // a single expression
                 let data_filters = data_filters?;
-                let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
+                let cnf_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
                     Arc::new(BinaryExpr::new(
                         left,
                         datafusion::logical_expr::Operator::And,
@@ -1064,29 +1041,21 @@
                 assert_eq!(file_groups.len(), partition_count);
 
                 let object_store_url = ObjectStoreUrl::local_filesystem();
+                let partition_fields: Vec<Field> = partition_schema
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Field::new(field.name(), field.data_type().clone(), field.is_nullable())
+                    })
+                    .collect_vec();
                 let mut file_scan_config =
-                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
+                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema))
                         .with_file_groups(file_groups)
                         .with_table_partition_cols(partition_fields);
 
-                // Check for projection, if so generate the vector and add to FileScanConfig.
-                let mut projection_vector: Vec<usize> =
-                    Vec::with_capacity(required_schema_arrow.fields.len());
-                // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-                required_schema_arrow.fields.iter().for_each(|field| {
-                    projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-                });
-
-                partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .for_each(|(idx, _)| {
-                        projection_vector.push(idx + data_schema_arrow.fields.len());
-                    });
-
                 assert_eq!(
                     projection_vector.len(),
-                    required_schema_arrow.fields.len() + partition_schema_arrow.len()
+                    required_schema.fields.len() + partition_schema.fields.len()
                 );
                 file_scan_config = file_scan_config.with_projection(Some(projection_vector));
 
@@ -1095,13 +1064,11 @@
                 table_parquet_options.global.pushdown_filters = true;
                 table_parquet_options.global.reorder_filters = true;
 
-                let mut builder = ParquetExecBuilder::new(file_scan_config)
-                    .with_table_parquet_options(table_parquet_options)
-                    .with_schema_adapter_factory(
-                        Arc::new(CometSchemaAdapterFactory::default()),
-                    );
+                let mut builder = ParquetExecBuilder::new(file_scan_config)
+                    .with_table_parquet_options(table_parquet_options)
+                    .with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));
 
-                if let Some(filter) = test_data_filters {
+                if let Some(filter) = cnf_data_filters {
                     builder = builder.with_predicate(filter);
                 }
 
@@ -2309,6 +2276,23 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::DecodeError> {
     }
 }
 
+fn convert_spark_types_to_arrow_schema(
+    spark_types: &[spark_operator::SparkStructField],
+) -> SchemaRef {
+    let arrow_fields = spark_types
+        .iter()
+        .map(|spark_type| {
+            Field::new(
+                String::clone(&spark_type.name),
+                to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+                spark_type.nullable,
+            )
+        })
+        .collect_vec();
+    let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
+    arrow_schema
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};
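For context, the projection vector serialized on the Scala side (see operator.proto and QueryPlanSerde.scala below) is interpreted here as plain column offsets: first the index of each required column within the file's data schema, then one offset per partition column appended after the data schema. A minimal sketch of that convention, using hypothetical column names rather than anything from this change:

use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

fn main() {
    // Hypothetical data schema with three file columns.
    let data_schema: SchemaRef = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Int64, true),
    ]));

    // Required columns "c" and "a", plus a single partition column that is
    // addressed as data_schema.fields().len() + 0.
    let projection_vector: Vec<usize> = vec![2, 0, data_schema.fields().len()];
    assert_eq!(projection_vector, vec![2, 0, 3]);
}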
native/core/src/execution/datafusion/schema_adapter.rs (3 changes: 2 additions & 1 deletion)
@@ -259,7 +259,8 @@ impl SchemaMapper for SchemaMapping {
                         EvalMode::Legacy,
                         "UTC",
                         false,
-                    )?.into_array(batch_col.len())
+                    )?
+                    .into_array(batch_col.len())
                     // and if that works, return the field and column.
                     .map(|new_col| (new_col, table_field.clone()))
                 })
native/proto/src/proto/operator.proto (13 changes: 10 additions & 3 deletions)
@@ -61,6 +61,12 @@ message SparkFilePartition {
   repeated SparkPartitionedFile partitioned_file = 1;
 }
 
+message SparkStructField {
+  string name = 1;
+  spark.spark_expression.DataType data_type = 2;
+  bool nullable = 3;
+}
+
 message Scan {
   repeated spark.spark_expression.DataType fields = 1;
   // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -75,11 +81,12 @@ message NativeScan {
   // is purely for informational purposes when viewing native query plans in
   // debug mode.
   string source = 2;
-  string required_schema = 3;
-  string data_schema = 4;
-  repeated spark.spark_expression.DataType partition_schema = 5;
+  repeated SparkStructField required_schema = 3;
+  repeated SparkStructField data_schema = 4;
+  repeated SparkStructField partition_schema = 5;
   repeated spark.spark_expression.Expr data_filters = 6;
   repeated SparkFilePartition file_partitions = 7;
+  repeated int64 projection_vector = 8;
 }
 
 message Projection {
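For orientation, prost generates a Rust struct from the SparkStructField message above, and convert_spark_types_to_arrow_schema in planner.rs maps it onto Arrow fields. A simplified stand-in is sketched below; the struct is illustrative only (the generated type lives in the spark_operator module and carries a spark_expression::DataType message rather than an Arrow DataType):

use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

// Illustrative stand-in for the prost-generated message.
struct SparkStructField {
    name: String,
    data_type: DataType,
    nullable: bool,
}

fn to_arrow_schema(fields: &[SparkStructField]) -> SchemaRef {
    let arrow_fields: Vec<Field> = fields
        .iter()
        .map(|f| Field::new(f.name.clone(), f.data_type.clone(), f.nullable))
        .collect();
    Arc::new(Schema::new(arrow_fields))
}

fn main() {
    let schema = to_arrow_schema(&[SparkStructField {
        name: "id".to_string(),
        data_type: DataType::Int64,
        nullable: false,
    }]);
    assert_eq!(schema.field(0).name(), "id");
}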
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (40 changes: 30 additions & 10 deletions)
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
-import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2520,18 +2519,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           case _ =>
         }
 
-        val requiredSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
-        val dataSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
-        val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
-          serializeDataType(field.dataType)
-        }
+        val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+        val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+        val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+        val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+          scan.relation.dataSchema.fieldIndex(field.name)
+        })
+        val partition_schema_idxs = Array
+          .range(
+            scan.relation.dataSchema.fields.length,
+            scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+        val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+          idx.toLong.asInstanceOf[java.lang.Long])
+
+        nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
 
         // In `CometScanRule`, we ensure partitionSchema is supported.
         assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
 
-        nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
-        nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+        nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+        nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
         nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
 
         Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3198,6 +3207,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     true
   }
 
+  private def schema2Proto(
+      fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+    val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+    fields.map(field => {
+      fieldBuilder.setName(field.name)
+      fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+      fieldBuilder.setNullable(field.nullable)
+      fieldBuilder.build()
+    })
+  }
+
   private def partition2Proto(
       partition: FilePartition,
       nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,