common/src/main/java/org/apache/comet/parquet/Native.java (1 addition, 6 deletions)
@@ -246,15 +246,10 @@ public static native void setPageV2(
* @param filePath
* @param start
* @param length
* @param required_columns array of names of fields to read
* @return a handle to the record batch reader, used in subsequent calls.
*/
public static native long initRecordBatchReader(
String filePath, long start, long length, Object[] required_columns);

public static native int numRowGroups(long handle);

public static native long numTotalRows(long handle);
String filePath, long fileSize, long start, long length, byte[] requiredSchema);

// arrow native version of read batch
/**
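The native entry point now takes the file size and an Arrow-IPC-serialized schema (`byte[]`) instead of an array of column names, and the row-group/row-count helpers are removed. A minimal Scala sketch of how a caller might invoke the new signature, mirroring what `NativeBatchReader.init()` does below; the file path, split offsets, and projected schema here are illustrative and not taken from this PR:

```scala
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.spark.sql.comet.CometArrowUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

import org.apache.comet.parquet.Native

object InitReaderSketch {
  def main(args: Array[String]): Unit = {
    // Projected Spark schema for the scan (illustrative).
    val sparkSchema = new StructType()
      .add("id", IntegerType)
      .add("name", StringType)

    // Encode the schema as an Arrow IPC schema message, as NativeBatchReader.init() does.
    val arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, "UTC")
    val out = new ByteArrayOutputStream()
    MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), arrowSchema)

    // Hypothetical values describing the split being read.
    val filePath = "/tmp/data.parquet"
    val fileSize = 1024L * 1024L
    val start = 0L
    val length = fileSize

    val handle = Native.initRecordBatchReader(filePath, fileSize, start, length, out.toByteArray)
    // ... read batches via Native.readNextRecordBatch(handle), then release the handle.
  }
}
```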
common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -19,11 +19,13 @@

package org.apache.comet.parquet;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.nio.channels.Channels;
import java.util.*;

import scala.Option;
@@ -36,6 +38,9 @@
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.InputSplit;
@@ -52,6 +57,7 @@
import org.apache.spark.TaskContext$;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.comet.CometArrowUtils;
import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
@@ -99,7 +105,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> implements
private PartitionedFile file;
private final Map<String, SQLMetric> metrics;

private long rowsRead;
private StructType sparkSchema;
private MessageType requestedSchema;
private CometVector[] vectors;
@@ -111,9 +116,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> implements
private boolean isInitialized;
private ParquetMetadata footer;

/** The total number of rows across all row groups of the input split. */
private long totalRowCount;

/**
* Whether the native scan should always return decimal represented by 128 bits, regardless of its
* precision. Normally, this should be true if native execution is enabled, since Arrow compute
@@ -224,6 +226,7 @@ public void init() throws URISyntaxException, IOException {
long start = file.start();
long length = file.length();
String filePath = file.filePath().toString();
long fileSize = file.fileSize();

requestedSchema = footer.getFileMetaData().getSchema();
MessageType fileSchema = requestedSchema;
@@ -254,6 +257,13 @@ public void init() throws URISyntaxException, IOException {
}
} ////// End get requested schema

String timeZoneId = conf.get("spark.sql.session.timeZone");
Schema arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, timeZoneId);
ByteArrayOutputStream out = new ByteArrayOutputStream();
WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
MessageSerializer.serialize(writeChannel, arrowSchema);
byte[] serializedRequestedArrowSchema = out.toByteArray();

//// Create Column readers
List<ColumnDescriptor> columns = requestedSchema.getColumns();
int numColumns = columns.size();
@@ -334,13 +344,9 @@ public void init() throws URISyntaxException, IOException {
}
}

// TODO: (ARROW NATIVE) Use a ProjectionMask here ?
ArrayList<String> requiredColumns = new ArrayList<>();
for (Type col : requestedSchema.asGroupType().getFields()) {
requiredColumns.add(col.getName());
}
this.handle = Native.initRecordBatchReader(filePath, start, length, requiredColumns.toArray());
totalRowCount = Native.numRowGroups(handle);
this.handle =
Native.initRecordBatchReader(
filePath, fileSize, start, length, serializedRequestedArrowSchema);
isInitialized = true;
}
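The bytes handed to `initRecordBatchReader` are a standard Arrow IPC schema message, so they can be decoded again on either side of the JNI boundary. A small round-trip sketch, not part of this PR; that the native side decodes them with arrow-ipc is an assumption:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.spark.sql.comet.CometArrowUtils
import org.apache.spark.sql.types.{LongType, StructType}

object SchemaRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val sparkSchema = new StructType().add("value", LongType)
    val arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, "UTC")

    // Serialize exactly as init() does above.
    val out = new ByteArrayOutputStream()
    MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), arrowSchema)

    // Decode the bytes back into an Arrow schema to confirm the encoding is lossless.
    val in = new ReadChannel(Channels.newChannel(new ByteArrayInputStream(out.toByteArray)))
    val decoded = MessageSerializer.deserializeSchema(in)
    assert(decoded == arrowSchema, "IPC round trip should preserve the schema")
  }
}
```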

@@ -375,7 +381,7 @@ public ColumnarBatch getCurrentValue() {

@Override
public float getProgress() {
return (float) rowsRead / totalRowCount;
return 0;
}

/**
@@ -395,7 +401,7 @@ public ColumnarBatch currentBatch() {
public boolean nextBatch() throws IOException {
Preconditions.checkState(isInitialized, "init() should be called first!");

if (rowsRead >= totalRowCount) return false;
// if (rowsRead >= totalRowCount) return false;
int batchSize;

try {
@@ -432,7 +438,6 @@ public boolean nextBatch() throws IOException {
}

currentBatch.setNumRows(batchSize);
rowsRead += batchSize;
return true;
}

@@ -457,6 +462,9 @@ private int loadNextBatch() throws Throwable {
long startNs = System.nanoTime();

int batchSize = Native.readNextRecordBatch(this.handle);
if (batchSize == 0) {
return batchSize;
}
if (importer != null) importer.close();
importer = new CometSchemaImporter(ALLOCATOR);

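With the row-count bookkeeping removed, the reader no longer knows the total number of rows up front: `getProgress()` returns 0 and iteration stops once the native side reports an empty batch. A sketch of draining an already-initialized reader; construction and `init()` are omitted because the constructor is not shown in this diff:

```scala
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.parquet.NativeBatchReader

object ConsumeBatchesSketch {
  // Drains an initialized reader. nextBatch() keeps returning true until the
  // native reader is exhausted (loadNextBatch() == 0), so the caller counts rows
  // itself rather than relying on an up-front total.
  def consume(reader: NativeBatchReader): Long = {
    var rows = 0L
    while (reader.nextBatch()) {
      val batch: ColumnarBatch = reader.currentBatch()
      rows += batch.numRows()
    }
    rows
  }
}
```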
common/src/main/scala/org/apache/comet/CometConf.scala (6 additions, 6 deletions)
@@ -75,7 +75,7 @@ object CometConf extends ShimCometConf {
"that to enable native vectorized execution, both this config and " +
"'spark.comet.exec.enabled' need to be enabled.")
.booleanConf
.createWithDefault(true)
.createWithDefault(false)

val COMET_FULL_NATIVE_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
"spark.comet.native.scan.enabled")
@@ -85,15 +85,15 @@
"read supported data sources (currently only Parquet is supported natively)." +
" By default, this config is true.")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)

val COMET_NATIVE_ARROW_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
val COMET_NATIVE_RECORDBATCH_READER_ENABLED: ConfigEntry[Boolean] = conf(
"spark.comet.native.arrow.scan.enabled")
.internal()
.doc(
"Whether to enable the fully native arrow based scan. When this is turned on, Spark will " +
"use Comet to read Parquet files natively via the Arrow based Parquet reader." +
" By default, this config is false.")
"Whether to enable the fully native datafusion based column reader. When this is turned on," +
" Spark will use Comet to read Parquet files natively via the Datafusion based Parquet" +
" reader. By default, this config is false.")
.booleanConf
.createWithDefault(false)

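For reference, a sketch of turning these scans on explicitly in a Spark session, using only the config keys visible above; registering the Comet plugin itself is omitted, and key names and defaults may differ across Comet versions:

```scala
import org.apache.spark.sql.SparkSession

object CometScanConfigSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("comet-native-scan")
      .config("spark.comet.exec.enabled", "true")
      // Fully native (DataFusion-based) Parquet scan; this PR flips its default to true.
      .config("spark.comet.native.scan.enabled", "true")
      // Internal flag for the native record batch reader (default remains false).
      .config("spark.comet.native.arrow.scan.enabled", "true")
      .getOrCreate()

    // Any Parquet read now goes through the configured Comet scan path.
    spark.read.parquet("/tmp/data.parquet").show()
    spark.stop()
  }
}
```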
common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala (new file, 180 additions)
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

object CometArrowUtils {

val rootAllocator = new RootAllocator(Long.MaxValue)

// todo: support more types.

/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
case BooleanType => ArrowType.Bool.INSTANCE
case ByteType => new ArrowType.Int(8, true)
case ShortType => new ArrowType.Int(8 * 2, true)
case IntegerType => new ArrowType.Int(8 * 4, true)
case LongType => new ArrowType.Int(8 * 8, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case StringType => ArrowType.Utf8.INSTANCE
case BinaryType => ArrowType.Binary.INSTANCE
case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
case DateType => new ArrowType.Date(DateUnit.DAY)
case TimestampType if timeZoneId == null =>
throw new IllegalStateException("Missing timezoneId where it is mandatory.")
case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
case TimestampNTZType =>
new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
case NullType => ArrowType.Null.INSTANCE
case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
case _ =>
throw new IllegalArgumentException()
}

def fromArrowType(dt: ArrowType): DataType = dt match {
case ArrowType.Bool.INSTANCE => BooleanType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
case float: ArrowType.FloatingPoint
if float.getPrecision() == FloatingPointPrecision.SINGLE =>
FloatType
case float: ArrowType.FloatingPoint
if float.getPrecision() == FloatingPointPrecision.DOUBLE =>
DoubleType
case ArrowType.Utf8.INSTANCE => StringType
case ArrowType.Binary.INSTANCE => BinaryType
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp
if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
TimestampNTZType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
YearMonthIntervalType()
case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
case _ => throw new IllegalArgumentException()
// throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
}

/** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
dt match {
case ArrayType(elementType, containsNull) =>
val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
new Field(
name,
fieldType,
Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava)
case StructType(fields) =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
new Field(
name,
fieldType,
fields
.map { field =>
toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
}
.toSeq
.asJava)
case MapType(keyType, valueType, valueContainsNull) =>
val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
// Note: Map Type struct can not be null, Struct Type key field can not be null
new Field(
name,
mapType,
Seq(
toArrowField(
MapVector.DATA_VECTOR_NAME,
new StructType()
.add(MapVector.KEY_NAME, keyType, nullable = false)
.add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
nullable = false,
timeZoneId)).asJava)
case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId)
case dataType =>
val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
new Field(name, fieldType, Seq.empty[Field].asJava)
}
}

def fromArrowField(field: Field): DataType = {
field.getType match {
case _: ArrowType.Map =>
val elementField = field.getChildren.get(0)
val keyType = fromArrowField(elementField.getChildren.get(0))
val valueType = fromArrowField(elementField.getChildren.get(1))
MapType(keyType, valueType, elementField.getChildren.get(1).isNullable)
case ArrowType.List.INSTANCE =>
val elementField = field.getChildren().get(0)
val elementType = fromArrowField(elementField)
ArrayType(elementType, containsNull = elementField.isNullable)
case ArrowType.Struct.INSTANCE =>
val fields = field.getChildren().asScala.map { child =>
val dt = fromArrowField(child)
StructField(child.getName, dt, child.isNullable)
}
StructType(fields.toArray)
case arrowType => fromArrowType(arrowType)
}
}

/**
* Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType
*/
def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
new Schema(schema.map { field =>
toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
}.asJava)
}

def fromArrowSchema(schema: Schema): StructType = {
StructType(schema.getFields.asScala.map { field =>
val dt = fromArrowField(field)
StructField(field.getName, dt, field.isNullable)
}.toArray)
}

/** Return Map with conf settings to be used in ArrowPythonRunner */
def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
val pandasColsByName = Seq(
SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
conf.pandasGroupedMapAssignColumnsByName.toString)
val arrowSafeTypeCheck = Seq(
SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
}

}
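A quick round-trip check of the Spark-to-Arrow type mapping added above; the schema is illustrative and assumes microsecond timestamps with an explicit time zone, as `toArrowType` requires:

```scala
import org.apache.spark.sql.comet.CometArrowUtils
import org.apache.spark.sql.types._

object ArrowTypeMappingSketch {
  def main(args: Array[String]): Unit = {
    // A schema exercising the nested types handled by toArrowField/fromArrowField.
    val sparkSchema = new StructType()
      .add("id", LongType, nullable = false)
      .add("tags", ArrayType(StringType, containsNull = true))
      .add("attrs", MapType(StringType, DoubleType, valueContainsNull = true))
      .add("ts", TimestampType)

    // Timestamps need a time zone id when mapping to Arrow.
    val arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, "UTC")
    val roundTripped = CometArrowUtils.fromArrowSchema(arrowSchema)

    assert(roundTripped == sparkSchema, s"expected $sparkSchema, got $roundTripped")
  }
}
```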
native/Cargo.lock (2 additions; generated lockfile, diff not shown)

native/Cargo.toml (2 additions)
@@ -37,7 +37,9 @@ arrow = { version = "53.2.0", features = ["prettyprint", "ffi", "chrono-tz"] }
arrow-array = { version = "53.2.0" }
arrow-buffer = { version = "53.2.0" }
arrow-data = { version = "53.2.0" }
arrow-ipc = { version = "53.2.0" }
arrow-schema = { version = "53.2.0" }
flatbuffers = { version = "24.3.25" }
parquet = { version = "53.2.0", default-features = false, features = ["experimental"] }
datafusion-common = { version = "43.0.0" }
datafusion = { version = "43.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "parquet"] }
native/core/Cargo.toml (2 additions)
@@ -40,6 +40,8 @@ arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-ipc = { workspace = true }
flatbuffers = { workspace = true }
parquet = { workspace = true, default-features = false, features = ["experimental"] }
half = { version = "2.4.1", default-features = false }
futures = "0.3.28"
native/core/src/execution/datafusion/mod.rs (1 addition, 1 deletion)
@@ -20,6 +20,6 @@
pub mod expressions;
mod operators;
pub mod planner;
mod schema_adapter;
pub(crate) mod schema_adapter;
pub mod shuffle_writer;
mod util;