75 changes: 41 additions & 34 deletions native/core/src/execution/datafusion/planner.rs
@@ -949,8 +949,8 @@ impl PhysicalPlanner {
))
}
OpStruct::NativeScan(scan) => {
let data_schema = parse_message_type(&*scan.data_schema).unwrap();
let required_schema = parse_message_type(&*scan.required_schema).unwrap();
let data_schema = parse_message_type(&scan.data_schema).unwrap();
let required_schema = parse_message_type(&scan.required_schema).unwrap();

let data_schema_descriptor =
parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
@@ -968,16 +968,6 @@
)
.unwrap(),
);
assert!(!required_schema_arrow.fields.is_empty());

let mut projection_vector: Vec<usize> =
Vec::with_capacity(required_schema_arrow.fields.len());
// TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
required_schema_arrow.fields.iter().for_each(|field| {
projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
});

assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());

// Convert the Spark expressions to Physical expressions
let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
@@ -997,39 +987,56 @@
))
});

let object_store_url = ObjectStoreUrl::local_filesystem();
let paths: Vec<Url> = scan
.path
.iter()
.map(|path| Url::parse(path).unwrap())
.collect();

let object_store = object_store::local::LocalFileSystem::new();
// register the object store with the runtime environment
let url = Url::try_from("file://").unwrap();
self.session_ctx
.runtime_env()
.register_object_store(&url, Arc::new(object_store));

let files: Vec<PartitionedFile> = paths
.iter()
.map(|path| PartitionedFile::from_path(path.path().to_string()).unwrap())
.collect();

// partition the files
// TODO really should partition the row groups

let mut file_groups = vec![vec![]; partition_count];
files.iter().enumerate().for_each(|(idx, file)| {
file_groups[idx % partition_count].push(file.clone());
// Generate file groups
let mut file_groups: Vec<Vec<PartitionedFile>> =
Vec::with_capacity(partition_count);
scan.file_partitions.iter().for_each(|partition| {
let mut files = Vec::with_capacity(partition.partitioned_file.len());
partition.partitioned_file.iter().for_each(|file| {
assert!(file.start + file.length <= file.file_size);
files.push(PartitionedFile::new_with_range(
Url::parse(file.file_path.as_ref())
.unwrap()
.path()
.to_string(),
file.file_size as u64,
file.start,
file.start + file.length,
));
});
file_groups.push(files);
});

let file_scan_config =
// TODO: I think we can remove partition_count in the future, but leave for testing.
assert_eq!(file_groups.len(), partition_count);

let object_store_url = ObjectStoreUrl::local_filesystem();
let mut file_scan_config =
FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
.with_file_groups(file_groups)
.with_projection(Some(projection_vector));
.with_file_groups(file_groups);

// Check for projection, if so generate the vector and add to FileScanConfig.
if !required_schema_arrow.fields.is_empty() {
let mut projection_vector: Vec<usize> =
Vec::with_capacity(required_schema_arrow.fields.len());
// TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
required_schema_arrow.fields.iter().for_each(|field| {
projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
});

assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());
file_scan_config = file_scan_config.with_projection(Some(projection_vector));
}

let mut table_parquet_options = TableParquetOptions::new();
// TODO: Maybe these are configs?
table_parquet_options.global.pushdown_filters = true;
table_parquet_options.global.reorder_filters = true;

@@ -1041,7 +1048,7 @@
}

let scan = builder.build();
return Ok((vec![], Arc::new(scan)));
Ok((vec![], Arc::new(scan)))
}
OpStruct::Scan(scan) => {
let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();
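A note on the byte-range handling introduced above: each incoming file slice carries a `(start, length)` range plus the total `file_size`, and the planner asserts `start + length <= file_size` before handing the slice to DataFusion as a `PartitionedFile` covering only that range. The sketch below isolates that mapping; the stand-in struct, file names, and offsets are hypothetical, and it assumes the `datafusion::datasource::listing::PartitionedFile::new_with_range` API exactly as used in the diff.

```rust
use datafusion::datasource::listing::PartitionedFile;

/// Hypothetical stand-in for the prost-generated SparkPartitionedFile message.
struct SparkPartitionedFile {
    file_path: String,
    start: i64,
    length: i64,
    file_size: i64,
}

/// Turn one Spark file slice into a DataFusion PartitionedFile limited to that byte range.
fn to_partitioned_file(file: &SparkPartitionedFile) -> PartitionedFile {
    // Same invariant the planner asserts: the slice must lie entirely inside the file.
    assert!(file.start + file.length <= file.file_size);
    PartitionedFile::new_with_range(
        file.file_path.clone(),
        file.file_size as u64,
        file.start,
        file.start + file.length, // exclusive end offset of the range
    )
}

fn main() {
    // Illustrative values only: one 4 KiB file split into two 2 KiB slices.
    let slices = [
        SparkPartitionedFile { file_path: "part-0.parquet".into(), start: 0, length: 2048, file_size: 4096 },
        SparkPartitionedFile { file_path: "part-0.parquet".into(), start: 2048, length: 2048, file_size: 4096 },
    ];
    let files: Vec<PartitionedFile> = slices.iter().map(to_partitioned_file).collect();
    println!("built {} range-limited files", files.len());
}
```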
21 changes: 17 additions & 4 deletions native/proto/src/proto/operator.proto
@@ -47,6 +47,19 @@ message Operator {
}
}

message SparkPartitionedFile {
string file_path = 1;
int64 start = 2;
int64 length = 3;
int64 file_size = 4;
}

// This name and the one above are not great, but they correspond to the (unfortunate) Spark names.
// I prepended "Spark" since I think there's a name collision on the native side, but we can revisit.
message SparkFilePartition {
repeated SparkPartitionedFile partitioned_file = 1;
}

message Scan {
repeated spark.spark_expression.DataType fields = 1;
// The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -61,10 +74,10 @@ message NativeScan {
// is purely for informational purposes when viewing native query plans in
// debug mode.
string source = 2;
repeated string path = 3;
string required_schema = 4;
string data_schema = 5;
repeated spark.spark_expression.Expr data_filters = 6;
string required_schema = 3;
string data_schema = 4;
repeated spark.spark_expression.Expr data_filters = 5;
repeated SparkFilePartition file_partitions = 6;
}

message Projection {
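For orientation when reading the planner changes above: with the prost-based codegen used for these protos, the two new messages surface on the Rust side roughly as the structs sketched below. Field names are inferred from how `planner.rs` accesses them (`file_partitions`, `partitioned_file`, `file_path`, and so on); the real generated code carries prost derives and tag attributes that are omitted here, so treat this as an approximation rather than the actual generated module.

```rust
/// Approximate shape of the generated type for SparkPartitionedFile:
/// one byte range of a single file, mirroring Spark's PartitionedFile.
#[derive(Clone, Debug, Default)]
pub struct SparkPartitionedFile {
    pub file_path: String, // file URI, e.g. "file:///tmp/part-0.parquet"
    pub start: i64,        // byte offset where this slice begins
    pub length: i64,       // number of bytes in this slice
    pub file_size: i64,    // total size of the file in bytes
}

/// Approximate shape of the generated type for SparkFilePartition:
/// the set of file slices a single Spark task reads.
#[derive(Clone, Debug, Default)]
pub struct SparkFilePartition {
    pub partitioned_file: Vec<SparkPartitionedFile>,
}
```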
21 changes: 18 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources.FileScanRDD
import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2496,16 +2497,30 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val dataFilters = scan.dataFilters.map(exprToProto(_, scan.output))
nativeScanBuilder.addAllDataFilters(dataFilters.map(_.get).asJava)

// Eventually we'll want to modify CometNativeScan to generate the file partitions
// for us without instantiating the RDD.
val file_partitions = scan.inputRDD.asInstanceOf[FileScanRDD].filePartitions;
file_partitions.foreach(partition => {
val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
partition.files.foreach(file => {
val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
fileBuilder
.setFilePath(file.pathUri.toString)
.setStart(file.start)
.setLength(file.length)
.setFileSize(file.fileSize)
partitionBuilder.addPartitionedFile(fileBuilder.build())
})
nativeScanBuilder.addFilePartitions(partitionBuilder.build())
})

val requiredSchemaParquet =
new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
val dataSchemaParquet =
new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)

nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
scan.relation.location.inputFiles.foreach { f =>
nativeScanBuilder.addPath(f)
}

Some(result.setNativeScan(nativeScanBuilder).build())
