diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java index 1ed01d326e..b33ec60db6 100644 --- a/common/src/main/java/org/apache/comet/parquet/Native.java +++ b/common/src/main/java/org/apache/comet/parquet/Native.java @@ -246,15 +246,10 @@ public static native void setPageV2( * @param filePath * @param start * @param length - * @param required_columns array of names of fields to read * @return a handle to the record batch reader, used in subsequent calls. */ public static native long initRecordBatchReader( - String filePath, long start, long length, Object[] required_columns); - - public static native int numRowGroups(long handle); - - public static native long numTotalRows(long handle); + String filePath, long fileSize, long start, long length, byte[] requiredSchema); // arrow native version of read batch /** diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index 3ac55ba4d9..8461bb506e 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -19,11 +19,13 @@ package org.apache.comet.parquet; +import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URISyntaxException; +import java.nio.channels.Channels; import java.util.*; import scala.Option; @@ -36,6 +38,9 @@ import org.apache.arrow.c.CometSchemaImporter; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; @@ -52,6 +57,7 @@ import org.apache.spark.TaskContext$; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.comet.CometArrowUtils; import org.apache.spark.sql.comet.parquet.CometParquetReadSupport; import org.apache.spark.sql.execution.datasources.PartitionedFile; import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter; @@ -99,7 +105,6 @@ public class NativeBatchReader extends RecordReader impleme private PartitionedFile file; private final Map metrics; - private long rowsRead; private StructType sparkSchema; private MessageType requestedSchema; private CometVector[] vectors; @@ -111,9 +116,6 @@ public class NativeBatchReader extends RecordReader impleme private boolean isInitialized; private ParquetMetadata footer; - /** The total number of rows across all row groups of the input split. */ - private long totalRowCount; - /** * Whether the native scan should always return decimal represented by 128 bits, regardless of its * precision. 
Normally, this should be true if native execution is enabled, since Arrow compute @@ -224,6 +226,7 @@ public void init() throws URISyntaxException, IOException { long start = file.start(); long length = file.length(); String filePath = file.filePath().toString(); + long fileSize = file.fileSize(); requestedSchema = footer.getFileMetaData().getSchema(); MessageType fileSchema = requestedSchema; @@ -254,6 +257,13 @@ public void init() throws URISyntaxException, IOException { } } ////// End get requested schema + String timeZoneId = conf.get("spark.sql.session.timeZone"); + Schema arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, timeZoneId); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out)); + MessageSerializer.serialize(writeChannel, arrowSchema); + byte[] serializedRequestedArrowSchema = out.toByteArray(); + //// Create Column readers List columns = requestedSchema.getColumns(); int numColumns = columns.size(); @@ -334,13 +344,9 @@ public void init() throws URISyntaxException, IOException { } } - // TODO: (ARROW NATIVE) Use a ProjectionMask here ? - ArrayList requiredColumns = new ArrayList<>(); - for (Type col : requestedSchema.asGroupType().getFields()) { - requiredColumns.add(col.getName()); - } - this.handle = Native.initRecordBatchReader(filePath, start, length, requiredColumns.toArray()); - totalRowCount = Native.numRowGroups(handle); + this.handle = + Native.initRecordBatchReader( + filePath, fileSize, start, length, serializedRequestedArrowSchema); isInitialized = true; } @@ -375,7 +381,7 @@ public ColumnarBatch getCurrentValue() { @Override public float getProgress() { - return (float) rowsRead / totalRowCount; + return 0; } /** @@ -395,7 +401,7 @@ public ColumnarBatch currentBatch() { public boolean nextBatch() throws IOException { Preconditions.checkState(isInitialized, "init() should be called first!"); - if (rowsRead >= totalRowCount) return false; + // if (rowsRead >= totalRowCount) return false; int batchSize; try { @@ -432,7 +438,6 @@ public boolean nextBatch() throws IOException { } currentBatch.setNumRows(batchSize); - rowsRead += batchSize; return true; } @@ -457,6 +462,9 @@ private int loadNextBatch() throws Throwable { long startNs = System.nanoTime(); int batchSize = Native.readNextRecordBatch(this.handle); + if (batchSize == 0) { + return batchSize; + } if (importer != null) importer.close(); importer = new CometSchemaImporter(ALLOCATOR); diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 275114a11c..fabdd30c4a 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -75,7 +75,7 @@ object CometConf extends ShimCometConf { "that to enable native vectorized execution, both this config and " + "'spark.comet.exec.enabled' need to be enabled.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val COMET_FULL_NATIVE_SCAN_ENABLED: ConfigEntry[Boolean] = conf( "spark.comet.native.scan.enabled") @@ -85,15 +85,15 @@ object CometConf extends ShimCometConf { "read supported data sources (currently only Parquet is supported natively)." 
+ " By default, this config is true.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) - val COMET_NATIVE_ARROW_SCAN_ENABLED: ConfigEntry[Boolean] = conf( + val COMET_NATIVE_RECORDBATCH_READER_ENABLED: ConfigEntry[Boolean] = conf( "spark.comet.native.arrow.scan.enabled") .internal() .doc( - "Whether to enable the fully native arrow based scan. When this is turned on, Spark will " + - "use Comet to read Parquet files natively via the Arrow based Parquet reader." + - " By default, this config is false.") + "Whether to enable the fully native datafusion based column reader. When this is turned on," + + " Spark will use Comet to read Parquet files natively via the Datafusion based Parquet" + + " reader. By default, this config is false.") .booleanConf .createWithDefault(false) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala new file mode 100644 index 0000000000..2f4f55fc0b --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.JavaConverters._ + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.complex.MapVector +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +object CometArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. + + /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw new IllegalStateException("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case _ => + throw new IllegalArgumentException() + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => + FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => + DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => + TimestampNTZType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case ArrowType.Null.INSTANCE => NullType + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => + YearMonthIntervalType() + case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() + case _ => throw new IllegalArgumentException() + // throw QueryExecutionErrors.unsupportedArrowTypeError(dt) + } + + /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ + def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { + dt match { + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field( + name, + fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field( + name, + fieldType, + fields + .map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) + } + .toSeq + .asJava) + case MapType(keyType, valueType, valueContainsNull) => + val mapType = new FieldType(nullable, new ArrowType.Map(false), null) + // Note: Map Type struct can not be null, Struct Type key field can not be null + new Field( + name, + mapType, + Seq( + toArrowField( + MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId)).asJava) + case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId) + case dataType => + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case _: ArrowType.Map => + val elementField = field.getChildren.get(0) + val keyType = fromArrowField(elementField.getChildren.get(0)) + val valueType = fromArrowField(elementField.getChildren.get(1)) + MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields.toArray) + case arrowType => fromArrowType(arrowType) + } + } + + /** + * Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType + */ + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { + new Schema(schema.map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }.toArray) + } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + val pandasColsByName = Seq( + SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> + conf.pandasGroupedMapAssignColumnsByName.toString) + val arrowSafeTypeCheck = Seq( + SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> + conf.arrowSafeTypeConversion.toString) + Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*) + } + +} diff --git a/native/Cargo.lock b/native/Cargo.lock index c3a664ff3e..27e9726836 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -898,6 +898,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", + "arrow-ipc", "arrow-schema", "assertables", "async-trait", @@ -913,6 +914,7 @@ dependencies = [ "datafusion-expr", "datafusion-functions-nested", "datafusion-physical-expr", + "flatbuffers", "flate2", "futures", "half", diff --git a/native/Cargo.toml b/native/Cargo.toml index 4b89231c78..b78c1d68f0 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -37,7 +37,9 @@ arrow = { version = "53.2.0", features = ["prettyprint", "ffi", "chrono-tz"] } arrow-array = { version = "53.2.0" } arrow-buffer = { version = "53.2.0" } arrow-data = { version = "53.2.0" } +arrow-ipc = { version = "53.2.0" } arrow-schema = { version = "53.2.0" } +flatbuffers = { version = "24.3.25" } parquet = { version = "53.2.0", default-features = false, features = ["experimental"] } datafusion-common = { version = "43.0.0" } datafusion = { version = "43.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "parquet"] } diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 8d30b38cf1..35035ff353 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -40,6 +40,8 @@ arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } +arrow-ipc = { workspace = true } +flatbuffers = { workspace = true } parquet = { workspace = true, default-features = false, features = ["experimental"] } half = { version = "2.4.1", default-features = false } futures = "0.3.28" diff --git a/native/core/src/execution/datafusion/mod.rs b/native/core/src/execution/datafusion/mod.rs index fb9c8829c0..af32b4be13 100644 --- a/native/core/src/execution/datafusion/mod.rs +++ b/native/core/src/execution/datafusion/mod.rs @@ -20,6 +20,6 @@ pub mod expressions; mod operators; pub mod planner; -mod schema_adapter; +pub(crate) mod schema_adapter; pub mod shuffle_writer; mod util; diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index afca606624..c234b6f7b4 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -23,7 +23,7 @@ pub use mutable_vector::*; pub mod util; pub mod read; -use std::fs::File; +use std::task::Poll; use std::{boxed::Box, ptr::NonNull, sync::Arc}; use crate::errors::{try_unwrap_or_throw, CometError}; @@ -42,17 
+42,21 @@ use jni::{ use crate::execution::operators::ExecutionError; use crate::execution::utils::SparkArrowConvert; +use crate::parquet::data_type::AsBytes; use arrow::buffer::{Buffer, MutableBuffer}; use arrow_array::{Array, RecordBatch}; -use jni::objects::{ - JBooleanArray, JLongArray, JObjectArray, JPrimitiveArray, JString, ReleaseMode, -}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder; +use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_common::config::TableParquetOptions; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use futures::{poll, StreamExt}; +use jni::objects::{JBooleanArray, JByteArray, JLongArray, JPrimitiveArray, JString, ReleaseMode}; use jni::sys::jstring; -use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; -use parquet::arrow::ProjectionMask; +use parquet::arrow::arrow_reader::ParquetRecordBatchReader; use read::ColumnReader; -use url::Url; -use util::jni::{convert_column_descriptor, convert_encoding}; +use util::jni::{convert_column_descriptor, convert_encoding, deserialize_schema, get_file_path}; use self::util::jni::TypePromotionInfo; @@ -600,11 +604,11 @@ enum ParquetReaderState { } /// Parquet read context maintained across multiple JNI calls. struct BatchContext { - batch_reader: ParquetRecordBatchReader, + runtime: tokio::runtime::Runtime, + batch_stream: Option, + batch_reader: Option, current_batch: Option, reader_state: ParquetReaderState, - num_row_groups: i32, - total_rows: i64, } #[inline] @@ -616,10 +620,12 @@ fn get_batch_context<'a>(handle: jlong) -> Result<&'a mut BatchContext, CometErr } } +/* #[inline] fn get_batch_reader<'a>(handle: jlong) -> Result<&'a mut ParquetRecordBatchReader, CometError> { - Ok(&mut get_batch_context(handle)?.batch_reader) + Ok(&mut get_batch_context(handle)?.batch_reader.unwrap()) } +*/ /// # Safety /// This function is inherently unsafe since it deals with raw pointers passed from JNI. @@ -628,118 +634,80 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat e: JNIEnv, _jclass: JClass, file_path: jstring, + file_size: jlong, start: jlong, length: jlong, - required_columns: jobjectArray, + required_schema: jbyteArray, ) -> jlong { try_unwrap_or_throw(&e, |mut env| unsafe { let path: String = env .get_string(&JString::from_raw(file_path)) .unwrap() .into(); - //TODO: (ARROW NATIVE) - this works only for 'file://' urls - let path = Url::parse(path.as_ref()).unwrap().to_file_path().unwrap(); - let file = File::open(path).unwrap(); - - // Create a async parquet reader builder with batch_size. - // batch_size is the number of rows to read up to buffer once from pages, defaults to 1024 - // TODO: (ARROW NATIVE) Use async reader ParquetRecordBatchStreamBuilder - let mut builder = ParquetRecordBatchReaderBuilder::try_new(file) - .unwrap() - .with_batch_size(8192); // TODO: (ARROW NATIVE) Use batch size configured in JVM - - let num_row_groups; - let mut total_rows: i64 = 0; - //TODO: (ARROW NATIVE) if we can get the ParquetMetadata serialized, we need not do this. 
- { - let metadata = builder.metadata(); - - let mut columns_to_read: Vec = Vec::new(); - let columns_to_read_array = JObjectArray::from_raw(required_columns); - let array_len = env.get_array_length(&columns_to_read_array)?; - let mut required_columns: Vec = Vec::new(); - for i in 0..array_len { - let p: JString = env - .get_object_array_element(&columns_to_read_array, i)? - .into(); - required_columns.push(env.get_string(&p)?.into()); - } - for (i, col) in metadata - .file_metadata() - .schema_descr() - .columns() - .iter() - .enumerate() - { - for required in required_columns.iter() { - if col.name().to_uppercase().eq(&required.to_uppercase()) { - columns_to_read.push(i); - break; - } - } - } - //TODO: (ARROW NATIVE) make this work for complex types (especially deeply nested structs) - let mask = - ProjectionMask::leaves(metadata.file_metadata().schema_descr(), columns_to_read); - // Set projection mask to read only root columns 1 and 2. - - let mut row_groups_to_read: Vec = Vec::new(); - // get row groups - - for (i, rg) in metadata.row_groups().iter().enumerate() { - let rg_start = rg.file_offset().unwrap(); - let rg_end = rg_start + rg.compressed_size(); - if rg_start >= start && rg_end <= start + length { - row_groups_to_read.push(i); - total_rows += rg.num_rows(); - } - } - num_row_groups = row_groups_to_read.len(); - builder = builder - .with_projection(mask) - .with_row_groups(row_groups_to_read.clone()) - } - - // Build a sync parquet reader. - let batch_reader = builder.build().unwrap(); + let batch_stream: Option; + let batch_reader: Option = None; + // TODO: (ARROW NATIVE) Use the common global runtime + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?; + + // EXPERIMENTAL - BEGIN + //TODO: Need an execution context and a spark plan equivalent so that we can reuse + // code from jni_api.rs + let (object_store_url, object_store_path) = get_file_path(path.clone()).unwrap(); + // TODO: (ARROW NATIVE) - Remove code duplication between this and POC 1 + // copy the input on-heap buffer to native + let required_schema_array = JByteArray::from_raw(required_schema); + let required_schema_buffer = env.convert_byte_array(&required_schema_array)?; + let required_schema_arrow = deserialize_schema(required_schema_buffer.as_bytes())?; + let mut partitioned_file = PartitionedFile::new_with_range( + String::new(), // Dummy file path. We will override this with our path so that url encoding does not occur + file_size as u64, + start, + start + length, + ); + partitioned_file.object_meta.location = object_store_path; + // We build the file scan config with the *required* schema so that the reader knows + // the output schema we want + let file_scan_config = FileScanConfig::new(object_store_url, Arc::new(required_schema_arrow)) + .with_file(partitioned_file) + // TODO: (ARROW NATIVE) - do partition columns in native + // - will need partition schema and partition values to do so + // .with_table_partition_cols(partition_fields) + ; + let mut table_parquet_options = TableParquetOptions::new(); + // TODO: Maybe these are configs? + table_parquet_options.global.pushdown_filters = true; + table_parquet_options.global.reorder_filters = true; + + let builder2 = ParquetExecBuilder::new(file_scan_config) + .with_table_parquet_options(table_parquet_options) + .with_schema_adapter_factory(Arc::new( + crate::execution::datafusion::schema_adapter::CometSchemaAdapterFactory::default(), + )); + + //TODO: (ARROW NATIVE) - predicate pushdown?? 
+ // builder = builder.with_predicate(filter); + + let scan = builder2.build(); + let ctx = TaskContext::default(); + let partition_index: usize = 0; + batch_stream = Some(scan.execute(partition_index, Arc::new(ctx))?); + + // EXPERIMENTAL - END let ctx = BatchContext { + runtime, + batch_stream, batch_reader, current_batch: None, reader_state: ParquetReaderState::Init, - num_row_groups: num_row_groups as i32, - total_rows, }; let res = Box::new(ctx); Ok(Box::into_raw(res) as i64) }) } -#[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_numRowGroups( - e: JNIEnv, - _jclass: JClass, - handle: jlong, -) -> jint { - try_unwrap_or_throw(&e, |_env| { - let context = get_batch_context(handle)?; - // Read data - Ok(context.num_row_groups) - }) as jint -} - -#[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_numTotalRows( - e: JNIEnv, - _jclass: JClass, - handle: jlong, -) -> jlong { - try_unwrap_or_throw(&e, |_env| { - let context = get_batch_context(handle)?; - // Read data - Ok(context.total_rows) - }) as jlong -} - #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch( e: JNIEnv, @@ -748,21 +716,39 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch( ) -> jint { try_unwrap_or_throw(&e, |_env| { let context = get_batch_context(handle)?; - let batch_reader = &mut context.batch_reader; - // Read data let mut rows_read: i32 = 0; - let batch = batch_reader.next(); - - match batch { - Some(record_batch) => { - let batch = record_batch?; - rows_read = batch.num_rows() as i32; - context.current_batch = Some(batch); - context.reader_state = ParquetReaderState::Reading; - } - None => { - context.current_batch = None; - context.reader_state = ParquetReaderState::Complete; + let batch_stream = context.batch_stream.as_mut().unwrap(); + let runtime = &context.runtime; + + // let mut stream = batch_stream.as_mut(); + loop { + let next_item = batch_stream.next(); + let poll_batch: Poll>> = + runtime.block_on(async { poll!(next_item) }); + + match poll_batch { + Poll::Ready(Some(batch)) => { + let batch = batch?; + rows_read = batch.num_rows() as i32; + context.current_batch = Some(batch); + context.reader_state = ParquetReaderState::Reading; + break; + } + Poll::Ready(None) => { + // EOF + + // TODO: (ARROW NATIVE) We can update metrics here + // crate::execution::jni_api::update_metrics(&mut env, exec_context)?; + + context.current_batch = None; + context.reader_state = ParquetReaderState::Complete; + break; + } + Poll::Pending => { + // TODO: (ARROW NATIVE): Just keeping polling?? + // Ideally we want to yield to avoid consuming CPU while blocked on IO ?? + continue; + } } } Ok(rows_read) diff --git a/native/core/src/parquet/util/jni.rs b/native/core/src/parquet/util/jni.rs index b61fbeab32..596277b379 100644 --- a/native/core/src/parquet/util/jni.rs +++ b/native/core/src/parquet/util/jni.rs @@ -24,11 +24,17 @@ use jni::{ JNIEnv, }; +use crate::execution::sort::RdxSort; +use arrow::error::ArrowError; +use arrow::ipc::reader::StreamReader; +use datafusion_execution::object_store::ObjectStoreUrl; +use object_store::path::Path; use parquet::{ basic::{Encoding, LogicalType, TimeUnit, Type as PhysicalType}, format::{MicroSeconds, MilliSeconds, NanoSeconds}, schema::types::{ColumnDescriptor, ColumnPath, PrimitiveTypeBuilder}, }; +use url::{ParseError, Url}; /// Convert primitives from Spark side into a `ColumnDescriptor`. 
#[allow(clippy::too_many_arguments)] @@ -198,3 +204,52 @@ fn fix_type_length(t: &PhysicalType, type_length: i32) -> i32 { _ => type_length, } } + +pub fn deserialize_schema(ipc_bytes: &[u8]) -> Result { + let reader = StreamReader::try_new(std::io::Cursor::new(ipc_bytes), None)?; + let schema = reader.schema().as_ref().clone(); + Ok(schema) +} + +// parses the url and returns a tuple of the scheme and object store path +pub fn get_file_path(url_: String) -> Result<(ObjectStoreUrl, Path), ParseError> { + // we define origin of a url as scheme + "://" + authority + ["/" + bucket] + let url = Url::parse(url_.as_ref()).unwrap(); + let mut object_store_origin = url.scheme().to_owned(); + let mut object_store_path = Path::from_url_path(url.path()).unwrap(); + if object_store_origin == "s3a" { + object_store_origin = "s3".to_string(); + object_store_origin.push_str("://"); + object_store_origin.push_str(url.authority()); + object_store_origin.push('/'); + let path_splits = url.path_segments().map(|c| c.collect::>()).unwrap(); + object_store_origin.push_str(path_splits.first().unwrap()); + let new_path = path_splits[1..path_splits.len() - 1].join("/"); + //TODO: (ARROW NATIVE) check the use of unwrap here + object_store_path = Path::from_url_path(new_path.clone().as_str()).unwrap(); + } else { + object_store_origin.push_str("://"); + object_store_origin.push_str(url.authority()); + object_store_origin.push('/'); + } + Ok(( + ObjectStoreUrl::parse(object_store_origin).unwrap(), + object_store_path, + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_file_path() { + let inp = "file:///comet/spark-warehouse/t1/part1=2019-01-01%2011%253A11%253A11/part-00000-84d7ed74-8f28-456c-9270-f45376eea144.c000.snappy.parquet"; + let expected = "comet/spark-warehouse/t1/part1=2019-01-01 11%3A11%3A11/part-00000-84d7ed74-8f28-456c-9270-f45376eea144.c000.snappy.parquet"; + + if let Ok((_obj_store_url, path)) = get_file_path(inp.to_string()) { + let actual = path.to_string(); + assert_eq!(actual, expected); + } + } +} diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs index a6d139716f..95fde37351 100644 --- a/native/spark-expr/src/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -571,7 +571,7 @@ impl SparkCastOptions { eval_mode, timezone: timezone.to_string(), allow_incompat, - is_adapting_schema: false + is_adapting_schema: false, } } @@ -583,7 +583,6 @@ impl SparkCastOptions { is_adapting_schema: false, } } - } /// Spark-compatible cast implementation. 
Defers to DataFusion's cast where that is known @@ -2309,8 +2308,7 @@ mod tests { #[test] fn test_cast_invalid_timezone() { let timestamps: PrimitiveArray = vec![i64::MAX].into(); - let cast_options = - SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false); + let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false); let result = cast_array( Arc::new(timestamps.with_timezone("Europe/Copenhagen")), &DataType::Date32, @@ -2401,9 +2399,7 @@ mod tests { let cast_array = spark_cast( ColumnarValue::Array(c), &DataType::Struct(fields), - &SparkCastOptions::new(EvalMode::Legacy, - "UTC", - false) + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), ) .unwrap(); if let ColumnarValue::Array(cast_array) = cast_array { diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala index eb524af906..e4235495f2 100644 --- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala +++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala @@ -41,7 +41,7 @@ trait DataTypeSupport { case t: DataType if t.typeName == "timestamp_ntz" => true case _: StructType if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED - .get() || CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get() => + .get() || CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.get() => true case _ => false } diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala index 4c96bef4e9..c142abb5cd 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -100,7 +100,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with // Comet specific configurations val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf) - val nativeArrowReaderEnabled = CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get(sqlConf) + val nativeArrowReaderEnabled = CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.get(sqlConf) (file: PartitionedFile) => { val sharedConf = broadcastedHadoopConf.value.value diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 99ed5d3cb2..e997c5bfd8 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -80,8 +80,8 @@ abstract class CometTestBase conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true") conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") - conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "true") - conf.set(CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.key, "false") + conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "false") + conf.set(CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key, "true") conf diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index a553e61c78..080655fe29 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -89,9 +89,13 @@ 
trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa actualSimplifiedPlan: String, actualExplain: String): Boolean = { val simplifiedFile = new File(dir, "simplified.txt") - val expectedSimplified = FileUtils.readFileToString(simplifiedFile, StandardCharsets.UTF_8) - lazy val explainFile = new File(dir, "explain.txt") - lazy val expectedExplain = FileUtils.readFileToString(explainFile, StandardCharsets.UTF_8) + var expectedSimplified = FileUtils.readFileToString(simplifiedFile, StandardCharsets.UTF_8) + val explainFile = new File(dir, "explain.txt") + var expectedExplain = FileUtils.readFileToString(explainFile, StandardCharsets.UTF_8) + if (!CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.get()) { + expectedExplain = expectedExplain.replace("CometNativeScan", "CometScan") + expectedSimplified = expectedSimplified.replace("CometNativeScan", "CometScan") + } expectedSimplified == actualSimplifiedPlan && expectedExplain == actualExplain } @@ -259,6 +263,9 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa // Disable char/varchar read-side handling for better performance. withSQLConf( CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true", + CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "false", + CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", @@ -288,6 +295,9 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") conf.set(CometConf.COMET_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") + conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "false") + conf.set(CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "1g") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
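
Reviewer note (not part of the patch): the new JVM-to-native contract in `initRecordBatchReader` passes the requested read schema as Arrow IPC stream bytes — `MessageSerializer.serialize(writeChannel, arrowSchema)` in `NativeBatchReader.init()` on the Java side, `util::jni::deserialize_schema` on the Rust side. The sketch below shows that round-trip on the Rust side only; `serialize_schema` here is a stand-in for the bytes produced by the JVM (using arrow-rs's `StreamWriter`, which emits the same schema message), and only `deserialize_schema` mirrors code that actually appears in the patch. It assumes the `arrow` crate already in the workspace.

```rust
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;

/// Stand-in for the bytes the JVM writes with
/// MessageSerializer.serialize(writeChannel, arrowSchema): a schema-only IPC stream.
fn serialize_schema(schema: &Schema) -> Result<Vec<u8>, ArrowError> {
    let mut buf = Vec::new();
    let mut writer = StreamWriter::try_new(&mut buf, schema)?; // writes the schema message
    writer.finish()?; // writes the end-of-stream marker; no record batches follow
    drop(writer); // release the borrow on buf
    Ok(buf)
}

/// Mirrors the patch's util::jni::deserialize_schema.
fn deserialize_schema(ipc_bytes: &[u8]) -> Result<Schema, ArrowError> {
    let reader = StreamReader::try_new(std::io::Cursor::new(ipc_bytes), None)?;
    Ok(reader.schema().as_ref().clone())
}

fn main() -> Result<(), ArrowError> {
    let schema = Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]);
    let bytes = serialize_schema(&schema)?;
    assert_eq!(deserialize_schema(&bytes)?, schema);
    Ok(())
}
```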
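
Also not part of the patch: `Native_readNextRecordBatch` now drives DataFusion's async `SendableRecordBatchStream` from a synchronous JNI entry point by polling `stream.next()` on a Tokio runtime. Below is a minimal sketch of that pattern under stated assumptions — it substitutes an in-memory `MemoryExec` for the `ParquetExec` the patch builds via `ParquetExecBuilder`, and it busy-waits on `Poll::Pending` exactly as the patch's TODO notes; it assumes the `datafusion`, `futures`, and `tokio` dependencies already present in `native/core`.

```rust
use std::sync::Arc;
use std::task::Poll;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use futures::{poll, StreamExt};

fn main() -> datafusion::error::Result<()> {
    // Stand-in input: one partition holding a single batch. The patch instead calls
    // ParquetExecBuilder::new(file_scan_config).build().execute(partition, ctx).
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;
    let exec = MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?;
    let mut stream = exec.execute(0, Arc::new(TaskContext::default()))?;

    // Synchronous caller (the JNI entry point in the patch) polling the async stream.
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;
    loop {
        match runtime.block_on(async { poll!(stream.next()) }) {
            Poll::Ready(Some(batch)) => println!("read {} rows", batch?.num_rows()),
            Poll::Ready(None) => break, // EOF: the reader reports 0 rows to the JVM
            Poll::Pending => continue,  // the patch simply keeps polling here
        }
    }
    Ok(())
}
```

Polling via `poll!` (rather than `block_on(stream.next())`) lets the JNI thread observe `Poll::Pending` explicitly; as the patch's TODO acknowledges, yielding instead of spinning on pending I/O is still an open question.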