diff --git a/native/proto/src/proto/datatype.proto b/native/proto/src/proto/datatype.proto new file mode 100644 index 0000000000..04bd05ec77 --- /dev/null +++ b/native/proto/src/proto/datatype.proto @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + + +syntax = "proto3"; + +package spark.spark_expression; + +option java_package = "org.apache.comet.serde"; + +message DataType { + enum DataTypeId { + BOOL = 0; + INT8 = 1; + INT16 = 2; + INT32 = 3; + INT64 = 4; + FLOAT = 5; + DOUBLE = 6; + STRING = 7; + BYTES = 8; + TIMESTAMP = 9; + DECIMAL = 10; + TIMESTAMP_NTZ = 11; + DATE = 12; + NULL = 13; + LIST = 14; + MAP = 15; + STRUCT = 16; + } + DataTypeId type_id = 1; + + message DataTypeInfo { + oneof datatype_struct { + DecimalInfo decimal = 2; + ListInfo list = 3; + MapInfo map = 4; + StructInfo struct = 5; + } + } + + message DecimalInfo { + int32 precision = 1; + int32 scale = 2; + } + + message ListInfo { + DataType element_type = 1; + bool contains_null = 2; + } + + message MapInfo { + DataType key_type = 1; + DataType value_type = 2; + bool value_contains_null = 3; + } + + message StructInfo { + repeated string field_names = 1; + repeated DataType field_datatypes = 2; + repeated bool field_nullable = 3; + } + + DataTypeInfo type_info = 2; +} diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 1152d7a1b2..bdd2c03ff2 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -21,6 +21,8 @@ syntax = "proto3"; package spark.spark_expression; +import "datatype.proto"; +import "literal.proto"; import "types.proto"; option java_package = "org.apache.comet.serde"; @@ -203,27 +205,6 @@ message BloomFilterAgg { DataType datatype = 4; } -message Literal { - oneof value { - bool bool_val = 1; - // Protobuf doesn't provide int8 and int16, we put them into int32 and convert - // to int8 and int16 when deserializing. - int32 byte_val = 2; - int32 short_val = 3; - int32 int_val = 4; - int64 long_val = 5; - float float_val = 6; - double double_val = 7; - string string_val = 8; - bytes bytes_val = 9; - bytes decimal_val = 10; - ListLiteral list_val = 11; - } - - DataType datatype = 12; - bool is_null = 13; -} - enum EvalMode { LEGACY = 0; TRY = 1; @@ -426,59 +407,3 @@ message ArrayJoin { message Rand { int64 seed = 1; } - -message DataType { - enum DataTypeId { - BOOL = 0; - INT8 = 1; - INT16 = 2; - INT32 = 3; - INT64 = 4; - FLOAT = 5; - DOUBLE = 6; - STRING = 7; - BYTES = 8; - TIMESTAMP = 9; - DECIMAL = 10; - TIMESTAMP_NTZ = 11; - DATE = 12; - NULL = 13; - LIST = 14; - MAP = 15; - STRUCT = 16; - } - DataTypeId type_id = 1; - - message DataTypeInfo { - oneof datatype_struct { - DecimalInfo decimal = 2; - ListInfo list = 3; - MapInfo map = 4; - StructInfo struct = 5; - } - } - - message DecimalInfo { - int32 precision = 1; - int32 scale = 2; - } - - message ListInfo { - DataType element_type = 1; - bool contains_null = 2; - } - - message MapInfo { - DataType key_type = 1; - DataType value_type = 2; - bool value_contains_null = 3; - } - - message StructInfo { - repeated string field_names = 1; - repeated DataType field_datatypes = 2; - repeated bool field_nullable = 3; - } - - DataTypeInfo type_info = 2; -} \ No newline at end of file diff --git a/native/proto/src/proto/literal.proto b/native/proto/src/proto/literal.proto new file mode 100644 index 0000000000..bfff137bc8 --- /dev/null +++ b/native/proto/src/proto/literal.proto @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + + +syntax = "proto3"; + +package spark.spark_expression; + +import "datatype.proto"; + +option java_package = "org.apache.comet.serde"; + +message Literal { + oneof value { + bool bool_val = 1; + // Protobuf doesn't provide int8 and int16, we put them into int32 and convert + // to int8 and int16 when deserializing. + int32 byte_val = 2; + int32 short_val = 3; + int32 int_val = 4; + int64 long_val = 5; + float float_val = 6; + double double_val = 7; + string string_val = 8; + bytes bytes_val = 9; + bytes decimal_val = 10; + } + + DataType datatype = 11; + bool is_null = 12; +} diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 5cb332ef03..0f1feb199f 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -21,6 +21,7 @@ syntax = "proto3"; package spark.spark_operator; +import "datatype.proto"; import "expr.proto"; import "partitioning.proto"; diff --git a/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala b/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala index 4ad467cd80..ac6a89ca3b 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala @@ -29,13 +29,14 @@ import org.apache.spark.sql.types._ import org.apache.comet.serde.ExprOuterClass import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.LiteralOuterClass import org.apache.comet.serde.QueryPlanSerde.serializeDataType object SourceFilterSerde extends Logging { def createNameExpr( name: String, - schema: StructType): Option[(DataType, ExprOuterClass.Expr)] = { + schema: StructType): Option[(org.apache.spark.sql.types.DataType, ExprOuterClass.Expr)] = { val filedWithIndex = schema.fields.zipWithIndex.find { case (field, _) => field.name == name } @@ -66,8 +67,10 @@ object SourceFilterSerde extends Logging { /** * create a literal value native expression for source filter value, the value is a scala value */ - def createValueExpr(value: Any, dataType: DataType): Option[ExprOuterClass.Expr] = { - val exprBuilder = ExprOuterClass.Literal.newBuilder() + def createValueExpr( + value: Any, + dataType: org.apache.spark.sql.types.DataType): Option[ExprOuterClass.Expr] = { + val exprBuilder = LiteralOuterClass.Literal.newBuilder() var valueIsSet = true if (value == null) { exprBuilder.setIsNull(true) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 23cf9d313e..d5984ddf98 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -52,8 +52,9 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig -import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} -import org.apache.comet.serde.ExprOuterClass.DataType._ +import org.apache.comet.serde.Datatype.{DataType => ProtoDataType} +import org.apache.comet.serde.Datatype.DataType._ +import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} import org.apache.comet.serde.Types.ListLiteral @@ -213,7 +214,7 @@ object QueryPlanSerde extends Logging with CometExprShim { * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return * false for it. */ - def serializeDataType(dt: DataType): Option[ExprOuterClass.DataType] = { + def serializeDataType(dt: org.apache.spark.sql.types.DataType): Option[Datatype.DataType] = { val typeId = dt match { case _: BooleanType => 0 case _: ByteType => 1 @@ -728,7 +729,7 @@ object QueryPlanSerde extends Logging with CometExprShim { .contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && dataType .isInstanceOf[ArrayType]) && !isComplexType( dataType.asInstanceOf[ArrayType].elementType)) => - val exprBuilder = ExprOuterClass.Literal.newBuilder() + val exprBuilder = LiteralOuterClass.Literal.newBuilder() if (value == null) { exprBuilder.setIsNull(true) diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala index 4996c3a34e..587a6e0b13 100644 --- a/spark/src/main/scala/org/apache/comet/serde/hash.scala +++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala @@ -34,7 +34,7 @@ object CometXxHash64 extends CometExpressionSerde[XxHash64] { return None } val exprs = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val seedBuilder = ExprOuterClass.Literal + val seedBuilder = LiteralOuterClass.Literal .newBuilder() .setDatatype(serializeDataType(LongType).get) .setLongVal(expr.seed) @@ -53,7 +53,7 @@ object CometMurmur3Hash extends CometExpressionSerde[Murmur3Hash] { return None } val exprs = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val seedBuilder = ExprOuterClass.Literal + val seedBuilder = LiteralOuterClass.Literal .newBuilder() .setDatatype(serializeDataType(IntegerType).get) .setIntVal(expr.seed)