diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 508a11a01c85..4ecfb7e3fed1 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -479,8 +479,7 @@ class Expression(google.protobuf.message.Message): def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The element type of the array. - This field is deprecated since Spark 4.1+ and should only be set - if the data_type field is not set. Use data_type field instead. + This field is deprecated since Spark 4.1+. Use data_type field instead. """ @property def elements( @@ -488,14 +487,19 @@ class Expression(google.protobuf.message.Message): ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ global___Expression.Literal ]: - """The literal values that make up the array elements.""" + """The literal values that make up the array elements. + + For inferring the data_type.element_type, only the first element needs to + contain the type information. + """ @property def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: - """The type of the array. + """The type of the array. You don't need to set this field if the type information is not needed. If the element type can be inferred from the first element of the elements field, - then you don't need to set data_type.element_type to save space. On the other hand, - redundant type information is also acceptable. + then you don't need to set data_type.element_type to save space. + + On the other hand, redundant type information is also acceptable. """ def __init__( self, @@ -534,8 +538,7 @@ class Expression(google.protobuf.message.Message): def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The key type of the map. - This field is deprecated since Spark 4.1+ and should only be set - if the data_type field is not set. Use data_type field instead. + This field is deprecated since Spark 4.1+. Use data_type field instead. """ @property def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: @@ -550,20 +553,29 @@ class Expression(google.protobuf.message.Message): ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ global___Expression.Literal ]: - """The literal keys that make up the map.""" + """The literal keys that make up the map. + + For inferring the data_type.key_type, only the first key needs to + contain the type information. + """ @property def values( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ global___Expression.Literal ]: - """The literal values that make up the map.""" + """The literal values that make up the map. + + For inferring the data_type.value_type, only the first value needs to + contain the type information. + """ @property def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: - """The type of the map. + """The type of the map. You don't need to set this field if the type information is not needed. If the key/value types can be inferred from the first element of the keys/values fields, then you don't need to set data_type.key_type/data_type.value_type to save space. + On the other hand, redundant type information is also acceptable. """ def __init__( @@ -608,8 +620,7 @@ class Expression(google.protobuf.message.Message): """(Deprecated) The type of the struct. 
This field is deprecated since Spark 4.1+ because using DataType as the type of a struct - is ambiguous. This field should only be set if the data_type_struct field is not set. - Use data_type_struct field instead. + is ambiguous. Use data_type_struct field instead. """ @property def elements( @@ -620,7 +631,7 @@ class Expression(google.protobuf.message.Message): """(Required) The literal values that make up the struct elements.""" @property def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct: - """The type of the struct. + """The type of the struct. You don't need to set this field if the type information is not needed. Whether data_type_struct.fields.data_type should be set depends on whether each field's type can be inferred from the elements field. diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index b760828a1e99..76572e41956f 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -3419,6 +3419,11 @@ class PlanGenerationTestSuite mutable.LinkedHashMap("a" -> 1, "b" -> 2), mutable.LinkedHashMap("a" -> 3, "b" -> 4), mutable.LinkedHashMap("a" -> 5, "b" -> 6))), + fn.typedLit( + Seq( + mutable.LinkedHashMap("a" -> Seq("1", "2"), "b" -> Seq("3", "4")), + mutable.LinkedHashMap("a" -> Seq("5", "6"), "b" -> Seq("7", "8")), + mutable.LinkedHashMap("a" -> Seq.empty[String], "b" -> Seq.empty[String]))), fn.typedLit( mutable.LinkedHashMap( 1 -> mutable.LinkedHashMap("a" -> 1, "b" -> 2), diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto index 913622b91a28..9bbec678b44f 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -217,26 +217,28 @@ message Expression { message Array { // (Deprecated) The element type of the array. // - // This field is deprecated since Spark 4.1+ and should only be set - // if the data_type field is not set. Use data_type field instead. + // This field is deprecated since Spark 4.1+. Use data_type field instead. DataType element_type = 1 [deprecated = true]; // The literal values that make up the array elements. + // + // For inferring the data_type.element_type, only the first element needs to + // contain the type information. repeated Literal elements = 2; - // The type of the array. + // The type of the array. You don't need to set this field if the type information is not needed. // // If the element type can be inferred from the first element of the elements field, - // then you don't need to set data_type.element_type to save space. On the other hand, - // redundant type information is also acceptable. + // then you don't need to set data_type.element_type to save space. + // + // On the other hand, redundant type information is also acceptable. DataType.Array data_type = 3; } message Map { // (Deprecated) The key type of the map. // - // This field is deprecated since Spark 4.1+ and should only be set - // if the data_type field is not set. Use data_type field instead. + // This field is deprecated since Spark 4.1+. Use data_type field instead. DataType key_type = 1 [deprecated = true]; // (Deprecated) The value type of the map. 
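To make the inference rule in the comments above concrete, here is a minimal sketch (not part of the patch; it assumes only the generated org.apache.spark.connect.proto classes) of how a client could encode Seq(1, 2, 3) under the new scheme: the element type travels with the first element, so data_type.element_type stays unset and only containsNull is written.

import org.apache.spark.connect.proto

// Each element is a plain integer literal; elements(0) carries enough
// information for the reader to infer the array's element type.
val ab = proto.Expression.Literal.Array.newBuilder()
Seq(1, 2, 3).foreach { v =>
  ab.addElements(proto.Expression.Literal.newBuilder().setInteger(v))
}
// data_type.element_type is left unset because it is inferable from the
// first element; containsNull still has to be carried explicitly.
ab.setDataType(proto.DataType.Array.newBuilder().setContainsNull(true))
val literal = proto.Expression.Literal.newBuilder().setArray(ab).build()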
@@ -246,15 +248,22 @@ message Expression { DataType value_type = 2 [deprecated = true]; // The literal keys that make up the map. + // + // For inferring the data_type.key_type, only the first key needs to + // contain the type information. repeated Literal keys = 3; // The literal values that make up the map. + // + // For inferring the data_type.value_type, only the first value needs to + // contain the type information. repeated Literal values = 4; - // The type of the map. + // The type of the map. You don't need to set this field if the type information is not needed. // // If the key/value types can be inferred from the first element of the keys/values fields, // then you don't need to set data_type.key_type/data_type.value_type to save space. + // // On the other hand, redundant type information is also acceptable. DataType.Map data_type = 5; } @@ -263,14 +272,13 @@ message Expression { // (Deprecated) The type of the struct. // // This field is deprecated since Spark 4.1+ because using DataType as the type of a struct - // is ambiguous. This field should only be set if the data_type_struct field is not set. - // Use data_type_struct field instead. + // is ambiguous. Use data_type_struct field instead. DataType struct_type = 1 [deprecated = true]; // (Required) The literal values that make up the struct elements. repeated Literal elements = 2; - // The type of the struct. + // The type of the struct. You don't need to set this field if the type information is not needed. // // Whether data_type_struct.fields.data_type should be set depends on // whether each field's type can be inferred from the elements field. diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index d9f860907b47..ca785ff098c6 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -39,10 +39,51 @@ import org.apache.spark.util.SparkClassUtils object LiteralValueProtoConverter { - @scala.annotation.tailrec + private def setArrayTypeAfterAddingElements( + ab: proto.Expression.Literal.Array.Builder, + elementType: DataType, + containsNull: Boolean, + useDeprecatedDataTypeFields: Boolean, + needDataType: Boolean): Unit = { + if (useDeprecatedDataTypeFields) { + ab.setElementType(toConnectProtoType(elementType)) + } else if (needDataType) { + val dataTypeBuilder = proto.DataType.Array.newBuilder() + if (ab.getElementsCount == 0 || getInferredDataType(ab.getElements(0)).isEmpty) { + dataTypeBuilder.setElementType(toConnectProtoType(elementType)) + } + dataTypeBuilder.setContainsNull(containsNull) + ab.setDataType(dataTypeBuilder.build()) + } + } + + private def setMapTypeAfterAddingKeysAndValues( + mb: proto.Expression.Literal.Map.Builder, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean, + useDeprecatedDataTypeFields: Boolean, + needDataType: Boolean): Unit = { + if (useDeprecatedDataTypeFields) { + mb.setKeyType(toConnectProtoType(keyType)) + mb.setValueType(toConnectProtoType(valueType)) + } else if (needDataType) { + val dataTypeBuilder = proto.DataType.Map.newBuilder() + if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeys(0)).isEmpty) { + dataTypeBuilder.setKeyType(toConnectProtoType(keyType)) + } + if (mb.getValuesCount == 0 || 
getInferredDataType(mb.getValues(0)).isEmpty) { + dataTypeBuilder.setValueType(toConnectProtoType(valueType)) + } + dataTypeBuilder.setValueContainsNull(valueContainsNull) + mb.setDataType(dataTypeBuilder.build()) + } + } + private def toLiteralProtoBuilderInternal( literal: Any, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { + options: ToLiteralProtoOptions, + needDataType: Boolean): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() def decimalBuilder(precision: Int, scale: Int, value: String) = { @@ -58,17 +99,17 @@ object LiteralValueProtoConverter { def arrayBuilder(array: Array[_]) = { val ab = builder.getArrayBuilder - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) - } else { - ab.setDataType( - proto.DataType.Array - .newBuilder() - .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) - .setContainsNull(true) - .build()) + var needElementType = needDataType + array.foreach { x => + ab.addElements(toLiteralProtoBuilderInternal(x, options, needElementType).build()) + needElementType = false } - array.foreach(x => ab.addElements(toLiteralProtoWithOptions(x, None, options))) + setArrayTypeAfterAddingElements( + ab, + toDataType(array.getClass.getComponentType), + containsNull = true, + options.useDeprecatedDataTypeFields, + needDataType) ab } @@ -88,8 +129,9 @@ object LiteralValueProtoConverter { case v: Char => builder.setString(v.toString) case v: Array[Char] => builder.setString(String.valueOf(v)) case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) - case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) - case v: immutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.unsafeArray, options) + case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options, needDataType) + case v: immutable.ArraySeq[_] => + toLiteralProtoBuilderInternal(v.unsafeArray, options, needDataType) case v: LocalDate => builder.setDate(v.toEpochDay.toInt) case v: Decimal => builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString)) @@ -113,36 +155,38 @@ object LiteralValueProtoConverter { } } - @scala.annotation.tailrec private def toLiteralProtoBuilderInternal( literal: Any, dataType: DataType, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { + options: ToLiteralProtoOptions, + needDataType: Boolean): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() def arrayBuilder(scalaValue: Any, elementType: DataType, containsNull: Boolean) = { val ab = builder.getArrayBuilder - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(elementType)) - } else { - ab.setDataType( - proto.DataType.Array - .newBuilder() - .setElementType(toConnectProtoType(elementType)) - .setContainsNull(containsNull) - .build()) - } + var needElementType = needDataType scalaValue match { case a: Array[_] => - a.foreach(item => - ab.addElements(toLiteralProtoWithOptions(item, Some(elementType), options))) + a.foreach { item => + ab.addElements( + toLiteralProtoBuilderInternal(item, elementType, options, needElementType).build()) + needElementType = false + } case s: scala.collection.Seq[_] => - s.foreach(item => - ab.addElements(toLiteralProtoWithOptions(item, Some(elementType), options))) + s.foreach { item => + ab.addElements( + toLiteralProtoBuilderInternal(item, elementType, options, 
needElementType).build()) + needElementType = false + } case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") } - + setArrayTypeAfterAddingElements( + ab, + elementType, + containsNull, + options.useDeprecatedDataTypeFields, + needDataType) ab } @@ -152,29 +196,26 @@ object LiteralValueProtoConverter { valueType: DataType, valueContainsNull: Boolean) = { val mb = builder.getMapBuilder - if (options.useDeprecatedDataTypeFields) { - mb.setKeyType(toConnectProtoType(keyType)) - mb.setValueType(toConnectProtoType(valueType)) - } else { - mb.setDataType( - proto.DataType.Map - .newBuilder() - .setKeyType(toConnectProtoType(keyType)) - .setValueType(toConnectProtoType(valueType)) - .setValueContainsNull(valueContainsNull) - .build()) - } - + var needKeyAndValueType = needDataType scalaValue match { case map: scala.collection.Map[_, _] => map.foreach { case (k, v) => - mb.addKeys(toLiteralProtoWithOptions(k, Some(keyType), options)) - mb.addValues(toLiteralProtoWithOptions(v, Some(valueType), options)) + mb.addKeys( + toLiteralProtoBuilderInternal(k, keyType, options, needKeyAndValueType).build()) + mb.addValues( + toLiteralProtoBuilderInternal(v, valueType, options, needKeyAndValueType).build()) + needKeyAndValueType = false } case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") } - + setMapTypeAfterAddingKeysAndValues( + mb, + keyType, + valueType, + valueContainsNull, + options.useDeprecatedDataTypeFields, + needDataType) mb } @@ -189,37 +230,42 @@ object LiteralValueProtoConverter { if (options.useDeprecatedDataTypeFields) { while (idx < structType.size) { val field = fields(idx) - val literalProto = - toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options) + // For backward compatibility, we need the data type for each field. 
+ val literalProto = toLiteralProtoBuilderInternal( + iter.next(), + field.dataType, + options, + needDataType = true).build() sb.addElements(literalProto) idx += 1 } sb.setStructType(toConnectProtoType(structType)) } else { - val dataTypeStruct = proto.DataType.Struct.newBuilder() while (idx < structType.size) { val field = fields(idx) val literalProto = - toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options) + toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType) + .build() sb.addElements(literalProto) - val fieldBuilder = dataTypeStruct - .addFieldsBuilder() - .setName(field.name) - .setNullable(field.nullable) + if (needDataType) { + val fieldBuilder = sb.getDataTypeStructBuilder + .addFieldsBuilder() + .setName(field.name) + .setNullable(field.nullable) - if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) { - fieldBuilder.setDataType(toConnectProtoType(field.dataType)) - } + if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) { + fieldBuilder.setDataType(toConnectProtoType(field.dataType)) + } - // Set metadata if available - if (field.metadata != Metadata.empty) { - fieldBuilder.setMetadata(field.metadata.json) + // Set metadata if available + if (field.metadata != Metadata.empty) { + fieldBuilder.setMetadata(field.metadata.json) + } } idx += 1 } - sb.setDataTypeStruct(dataTypeStruct.build()) } case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") @@ -230,11 +276,11 @@ object LiteralValueProtoConverter { (literal, dataType) match { case (v: mutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v.array, dataType, options) + toLiteralProtoBuilderInternal(v.array, dataType, options, needDataType) case (v: immutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v.unsafeArray, dataType, options) + toLiteralProtoBuilderInternal(v.unsafeArray, dataType, options, needDataType) case (v: Array[Byte], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v, options) + toLiteralProtoBuilderInternal(v, options, needDataType) case (v, ArrayType(elementType, containsNull)) => builder.setArray(arrayBuilder(v, elementType, containsNull)) case (v, MapType(keyType, valueType, valueContainsNull)) => @@ -243,7 +289,7 @@ object LiteralValueProtoConverter { builder.setStruct(structBuilder(v, structType)) case (v: Option[_], _: DataType) => if (v.isDefined) { - toLiteralProtoBuilderInternal(v.get, options) + toLiteralProtoBuilderInternal(v.get, options, needDataType) } else { builder.setNull(toConnectProtoType(dataType)) } @@ -252,7 +298,7 @@ object LiteralValueProtoConverter { builder.getTimeBuilder .setNano(SparkDateTimeUtils.localTimeToNanos(v)) .setPrecision(timeType.precision)) - case _ => toLiteralProtoBuilderInternal(literal, options) + case _ => toLiteralProtoBuilderInternal(literal, options, needDataType) } } @@ -266,7 +312,8 @@ object LiteralValueProtoConverter { def toLiteralProtoBuilder(literal: Any): proto.Expression.Literal.Builder = { toLiteralProtoBuilderInternal( literal, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true) } def toLiteralProtoBuilder( @@ -275,7 +322,8 @@ object LiteralValueProtoConverter { toLiteralProtoBuilderInternal( literal, dataType, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true) } def toLiteralProtoBuilderWithOptions( @@ 
-284,9 +332,9 @@ object LiteralValueProtoConverter { options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { dataTypeOpt match { case Some(dataType) => - toLiteralProtoBuilderInternal(literal, dataType, options) + toLiteralProtoBuilderInternal(literal, dataType, options, needDataType = true) case None => - toLiteralProtoBuilderInternal(literal, options) + toLiteralProtoBuilderInternal(literal, options, needDataType = true) } } @@ -308,13 +356,15 @@ object LiteralValueProtoConverter { def toLiteralProto(literal: Any): proto.Expression.Literal = toLiteralProtoBuilderInternal( literal, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)).build() + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true).build() def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal = toLiteralProtoBuilderInternal( literal, dataType, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)).build() + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true).build() def toLiteralProtoWithOptions( literal: Any, @@ -322,9 +372,9 @@ object LiteralValueProtoConverter { options: ToLiteralProtoOptions): proto.Expression.Literal = { dataTypeOpt match { case Some(dataType) => - toLiteralProtoBuilderInternal(literal, dataType, options).build() + toLiteralProtoBuilderInternal(literal, dataType, options, needDataType = true).build() case None => - toLiteralProtoBuilderInternal(literal, options).build() + toLiteralProtoBuilderInternal(literal, options, needDataType = true).build() } } @@ -414,6 +464,9 @@ object LiteralValueProtoConverter { case proto.Expression.Literal.LiteralTypeCase.ARRAY => toScalaArray(literal.getArray) + case proto.Expression.Literal.LiteralTypeCase.MAP => + toScalaMap(literal.getMap) + case proto.Expression.Literal.LiteralTypeCase.STRUCT => toScalaStruct(literal.getStruct) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain index 943b353a14cc..94871b4e66ef 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain @@ -1,2 +1,2 @@ -Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! 
AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, INTERVAL '0-0' YEAR TO MONTH AS INTERVAL '0-0' YEAR TO MONTH#0, 23:59:59.999999999 AS TIME '23:59:59.999999999'#0, 2 months 20 days 0.0001 seconds AS INTERVAL '2 months 20 days 0.0001 seconds'#0, 1 AS 1#0, [1,2,3] AS ARRAY(1, 2, 3)#0, [1,2,3] AS ARRAY(1, 2, 3)#0, map(keys: [a,b], values: [1,2]) AS MAP('a', 1, 'b', 2)#0, [a,2,1.0] AS NAMED_STRUCT('_1', 'a', '_2', 2, '_3', 1.0D)#0, null AS NULL#0, [1] AS ARRAY(1)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, [[1,2,3],[4,5,6],[7,8,9]] AS ARRAY(ARRAY(1, 2, 3), ARRAY(4, 5, 6), ARRAY(7, 8, 9))#0, [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4],keys: [a,b], values: [5,6]] AS ARRAY(MAP('a', 1, 'b', 2), MAP('a', 3, 'b', 4), MAP('a', 5, 'b', 6))#0, map(keys: [1,2], values: [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4]]) AS MAP(1, MAP('a', 1, 'b', 2), 2, MAP('a', 3, 'b', 4))#0, [[1,2,3],keys: [a,b], values: [1,2],[a,keys: [1,2], values: [a,b]]] AS NAMED_STRUCT('_1', ARRAY(1, 2, 3), '_2', MAP('a', 1, 'b', 2), '_3', NAMED_STRUCT('_1', 'a', '_2', MAP(1, 'a', 2, 'b')))#0] +Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! 
AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, INTERVAL '0-0' YEAR TO MONTH AS INTERVAL '0-0' YEAR TO MONTH#0, 23:59:59.999999999 AS TIME '23:59:59.999999999'#0, 2 months 20 days 0.0001 seconds AS INTERVAL '2 months 20 days 0.0001 seconds'#0, 1 AS 1#0, [1,2,3] AS ARRAY(1, 2, 3)#0, [1,2,3] AS ARRAY(1, 2, 3)#0, map(keys: [a,b], values: [1,2]) AS MAP('a', 1, 'b', 2)#0, [a,2,1.0] AS NAMED_STRUCT('_1', 'a', '_2', 2, '_3', 1.0D)#0, null AS NULL#0, [1] AS ARRAY(1)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, [[1,2,3],[4,5,6],[7,8,9]] AS ARRAY(ARRAY(1, 2, 3), ARRAY(4, 5, 6), ARRAY(7, 8, 9))#0, [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4],keys: [a,b], values: [5,6]] AS ARRAY(MAP('a', 1, 'b', 2), MAP('a', 3, 'b', 4), MAP('a', 5, 'b', 6))#0, [keys: [a,b], values: [[1,2],[3,4]],keys: [a,b], values: [[5,6],[7,8]],keys: [a,b], values: [[],[]]] AS ARRAY(MAP('a', ARRAY('1', '2'), 'b', ARRAY('3', '4')), MAP('a', ARRAY('5', '6'), 'b', ARRAY('7', '8')), MAP('a', ARRAY(), 'b', ARRAY()))#0, map(keys: [1,2], values: [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4]]) AS MAP(1, MAP('a', 1, 'b', 2), 2, MAP('a', 3, 'b', 4))#0, [[1,2,3],keys: [a,b], values: [1,2],[a,keys: [1,2], values: [a,b]]] AS NAMED_STRUCT('_1', ARRAY(1, 2, 3), '_2', MAP('a', 1, 'b', 2), '_3', NAMED_STRUCT('_1', 'a', '_2', MAP(1, 'a', 2, 'b')))#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json index cedf7572a1fd..a899c9f410aa 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json @@ -364,10 +364,6 @@ "integer": 6 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin index 5d30f4fca159..26c7b3a7dc02 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json index 53b1a7b3947f..153478ce75bb 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -49,10 +49,6 @@ "integer": 1 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -60,39 +56,16 @@ "array": { "elements": [{ "integer": 2 - }], - "dataType": { - "elementType": { - "integer": { - } - }, - "containsNull": true - 
} + }] } }, { "array": { "elements": [{ "integer": 3 - }], - "dataType": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } + }] } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -125,24 +98,11 @@ "integer": 1 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -152,28 +112,9 @@ "array": { "elements": [{ "integer": 2 - }], - "dataType": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } + }] } - }], - "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } + }] } }, { "array": { @@ -181,45 +122,12 @@ "array": { "elements": [{ "integer": 3 - }], - "dataType": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } + }] } - }], - "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } + }] } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -250,10 +158,6 @@ "boolean": false }], "dataType": { - "elementType": { - "boolean": { - } - }, "containsNull": true } } @@ -307,10 +211,6 @@ "short": 9874 }], "dataType": { - "elementType": { - "short": { - } - }, "containsNull": true } } @@ -343,10 +243,6 @@ "integer": -8726533 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -379,10 +275,6 @@ "long": "7834609328726533" }], "dataType": { - "elementType": { - "long": { - } - }, "containsNull": true } } @@ -415,10 +307,6 @@ "double": 2.0 }], "dataType": { - "elementType": { - "double": { - } - }, "containsNull": true } } @@ -451,10 +339,6 @@ "float": -0.9 }], "dataType": { - "elementType": { - "float": { - } - }, "containsNull": true } } @@ -664,10 +548,6 @@ "date": 18546 }], "dataType": { - "elementType": { - "date": { - } - }, "containsNull": true } } @@ -698,10 +578,6 @@ "timestamp": "1677155519809000" }], "dataType": { - "elementType": { - "timestamp": { - } - }, "containsNull": true } } @@ -732,10 +608,6 @@ "timestamp": "23456000" }], "dataType": { - "elementType": { - "timestamp": { - } - }, "containsNull": true } } @@ -766,10 +638,6 @@ "timestampNtz": "1677188160000000" }], "dataType": { - "elementType": { - "timestampNtz": { - } - }, "containsNull": true } } @@ -800,10 +668,6 @@ "date": 19417 }], "dataType": { - "elementType": { - "date": { - } - }, "containsNull": true } } @@ -914,10 +778,6 @@ } }], "dataType": { - "elementType": { - "calendarInterval": { - } - }, "containsNull": true } } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin index 8cb965dd25a0..d9edb4100b0d 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json 
b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json index 66bf31d670f9..447d225a1ca9 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json @@ -409,10 +409,6 @@ "integer": 6 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, @@ -710,10 +706,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, @@ -745,10 +737,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, @@ -787,10 +775,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -887,10 +871,6 @@ "integer": 1 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -925,14 +905,6 @@ } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "valueContainsNull": true } } @@ -967,14 +939,6 @@ } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "valueContainsNull": true } } @@ -1009,14 +973,6 @@ } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "valueContainsNull": true } } @@ -1051,10 +1007,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, { @@ -1065,13 +1017,7 @@ "integer": 5 }, { "integer": 6 - }], - "dataType": { - "elementType": { - "integer": { - } - } - } + }] } }, { "array": { @@ -1081,24 +1027,10 @@ "integer": 8 }, { "integer": 9 - }], - "dataType": { - "elementType": { - "integer": { - } - } - } + }] } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - } - } - }, "containsNull": true } } @@ -1140,10 +1072,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1158,18 +1086,7 @@ "integer": 3 }, { "integer": 4 - }], - "dataType": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } + }] } }, { "map": { @@ -1182,6 +1099,65 @@ "integer": 5 }, { "integer": 6 + }] + } + }], + "dataType": { + "containsNull": true + } + } + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "typedLit", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }, { + "literal": { + "array": { + "elements": [{ + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + "values": [{ + "array": { + "elements": [{ + "string": "1" + }, { + "string": "2" + }], + "dataType": { + "elementType": { + "string": { + "collation": "UTF8_BINARY" + } + }, + "containsNull": true + } + } + }, { + "array": { + "elements": [{ + "string": "3" + }, { + "string": "4" + }] + } }], "dataType": { "keyType": { @@ -1189,27 +1165,51 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - } + "valueContainsNull": true } } + }, { + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + "values": [{ + "array": { + "elements": [{ + "string": "5" + }, { + "string": "6" + }] + } + }, { + "array": { + "elements": [{ + "string": "7" + }, { + "string": "8" + }] + } + }] + } + }, { + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + 
"values": [{ + "array": { + } + }, { + "array": { + } + }] + } }], "dataType": { - "elementType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "containsNull": true } } @@ -1256,10 +1256,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1274,38 +1270,10 @@ "integer": 3 }, { "integer": 4 - }], - "dataType": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } + }] } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "valueContainsNull": true } } @@ -1340,10 +1308,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, { @@ -1363,10 +1327,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1387,10 +1347,6 @@ "string": "b" }], "dataType": { - "keyType": { - "integer": { - } - }, "valueType": { "string": { "collation": "UTF8_BINARY" diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index b3ebe8a79e3e..9a8e1c1a10a6 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index 7f862d371a16..addf94ed3460 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -170,7 +170,7 @@ private[ml] object MLUtils { * @return * the reconciled array */ - private def reconcileArray(elementType: Class[_], array: Array[_]): Array[_] = { + private[ml] def reconcileArray(elementType: Class[_], array: Array[_]): Array[_] = { if (elementType == classOf[Byte]) { array.map(_.asInstanceOf[Byte]) } else if (elementType == classOf[Short]) { @@ -187,6 +187,8 @@ private[ml] object MLUtils { array.map(_.asInstanceOf[String]) } else if (elementType.isArray && elementType.getComponentType == classOf[Double]) { array.map(_.asInstanceOf[Array[_]].map(_.asInstanceOf[Double])) + } else if (elementType.isArray && elementType.getComponentType == classOf[String]) { + array.map(_.asInstanceOf[Array[_]].map(_.asInstanceOf[String])) } else { throw MlUnsupportedException( s"array element type unsupported, " + diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala index df07dd42bc42..6e01090f8087 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala @@ -165,18 +165,21 @@ private[ml] object Serializer { case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => (literal.getBoolean.asInstanceOf[Object], classOf[Boolean]) case proto.Expression.Literal.LiteralTypeCase.ARRAY => - val array = literal.getArray - array.getElementType.getKindCase match { + val scalaArray = 
LiteralValueProtoConverter.toScalaArray(literal.getArray) + val arrayType = LiteralValueProtoConverter.getProtoArrayType(literal.getArray) + arrayType.getElementType.getKindCase match { case proto.DataType.KindCase.DOUBLE => - (parseDoubleArray(array), classOf[Array[Double]]) + (MLUtils.reconcileArray(classOf[Double], scalaArray), classOf[Array[Double]]) case proto.DataType.KindCase.STRING => - (parseStringArray(array), classOf[Array[String]]) + (MLUtils.reconcileArray(classOf[String], scalaArray), classOf[Array[String]]) case proto.DataType.KindCase.ARRAY => - array.getElementType.getArray.getElementType.getKindCase match { + arrayType.getElementType.getArray.getElementType.getKindCase match { case proto.DataType.KindCase.STRING => - (parseStringArrayArray(array), classOf[Array[Array[String]]]) + ( + MLUtils.reconcileArray(classOf[Array[String]], scalaArray), + classOf[Array[Array[String]]]) case _ => - throw MlUnsupportedException(s"Unsupported inner array $array") + throw MlUnsupportedException(s"Unsupported inner array ${literal.getArray}") } case _ => throw MlUnsupportedException(s"Unsupported array $literal") @@ -193,37 +196,6 @@ private[ml] object Serializer { } } - private def parseDoubleArray(array: proto.Expression.Literal.Array): Array[Double] = { - val values = new Array[Double](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = array.getElements(i).getDouble - i += 1 - } - values - } - - private def parseStringArray(array: proto.Expression.Literal.Array): Array[String] = { - val values = new Array[String](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = array.getElements(i).getString - i += 1 - } - values - } - - private def parseStringArrayArray( - array: proto.Expression.Literal.Array): Array[Array[String]] = { - val values = new Array[Array[String]](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = parseStringArray(array.getElements(i).getArray) - i += 1 - } - values - } - /** * Serialize an instance of "Params" which could be estimator/model/evaluator ... 
* @param instance diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index fc89580871d9..dfde32c5fc3e 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.planner import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.ToLiteralProtoOptions import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter @@ -73,9 +74,25 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i StructField("a", IntegerType), StructField( "b", - StructType( - Seq(StructField("c", IntegerType), StructField("d", IntegerType)))))))).zipWithIndex - .foreach { case ((v, t), idx) => + StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))))), + (Array(true, false, true), ArrayType(BooleanType)), + (Array(1.toByte, 2.toByte, 3.toByte), ArrayType(ByteType)), + (Array(1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)), + (Array(1, 2, 3), ArrayType(IntegerType)), + (Array(1L, 2L, 3L), ArrayType(LongType)), + (Array(1.1d, 2.1d, 3.1d), ArrayType(DoubleType)), + (Array(1.1f, 2.1f, 3.1f), ArrayType(FloatType)), + (Array(Array[Int](), Array(1, 2, 3), Array(4, 5, 6)), ArrayType(ArrayType(IntegerType))), + (Array(Array(1, 2, 3), Array(4, 5, 6), Array[Int]()), ArrayType(ArrayType(IntegerType))), + ( + Array(Array(Array(Array(Array(Array(1, 2, 3)))))), + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), + (Array(Map(1 -> 2)), ArrayType(MapType(IntegerType, IntegerType))), + (Map[String, String]("1" -> "2", "3" -> "4"), MapType(StringType, StringType)), + (Map[String, Boolean]("1" -> true, "2" -> false), MapType(StringType, BooleanType)), + (Map[Int, Int](), MapType(IntegerType, IntegerType)), + (Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType))).zipWithIndex.foreach { + case ((v, t), idx) => test(s"complex proto value and catalyst value conversion #$idx") { assertResult(v)( LiteralValueProtoConverter.toScalaValue( @@ -93,23 +110,18 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i Some(t), ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)))) } - } + } test("backward compatibility for array literal proto") { // Test the old way of defining arrays with elementType field and elements - val arrayProto = proto.Expression.Literal.Array - .newBuilder() - .setElementType( - proto.DataType - .newBuilder() - .setInteger(proto.DataType.Integer.newBuilder()) - .build()) - .addElements(toLiteralProto(1)) - .addElements(toLiteralProto(2)) - .addElements(toLiteralProto(3)) - .build() + val literalProto = LiteralValueProtoConverter.toLiteralProtoWithOptions( + Seq(1, 2, 3), + Some(ArrayType(IntegerType, containsNull = false)), + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + assert(!literalProto.getArray.hasDataType) + assert(literalProto.getArray.getElementsList.size == 3) + 
assert(literalProto.getArray.getElementType.hasInteger) - val literalProto = proto.Expression.Literal.newBuilder().setArray(arrayProto).build() val literal = LiteralExpressionProtoConverter.toCatalystExpression(literalProto) assert(literal.dataType.isInstanceOf[ArrayType]) assert(literal.dataType.asInstanceOf[ArrayType].elementType == IntegerType) @@ -125,25 +137,16 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i test("backward compatibility for map literal proto") { // Test the old way of defining maps with keyType and valueType fields - val mapProto = proto.Expression.Literal.Map - .newBuilder() - .setKeyType( - proto.DataType - .newBuilder() - .setString(proto.DataType.String.newBuilder()) - .build()) - .setValueType( - proto.DataType - .newBuilder() - .setInteger(proto.DataType.Integer.newBuilder()) - .build()) - .addKeys(toLiteralProto("a")) - .addKeys(toLiteralProto("b")) - .addValues(toLiteralProto(1)) - .addValues(toLiteralProto(2)) - .build() + val literalProto = LiteralValueProtoConverter.toLiteralProtoWithOptions( + Map[String, Int]("a" -> 1, "b" -> 2), + Some(MapType(StringType, IntegerType, valueContainsNull = false)), + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + assert(!literalProto.getMap.hasDataType) + assert(literalProto.getMap.getKeysList.size == 2) + assert(literalProto.getMap.getValuesList.size == 2) + assert(literalProto.getMap.getKeyType.hasString) + assert(literalProto.getMap.getValueType.hasInteger) - val literalProto = proto.Expression.Literal.newBuilder().setMap(mapProto).build() val literal = LiteralExpressionProtoConverter.toCatalystExpression(literalProto) assert(literal.dataType.isInstanceOf[MapType]) assert(literal.dataType.asInstanceOf[MapType].keyType == StringType) @@ -163,39 +166,25 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i test("backward compatibility for struct literal proto") { // Test the old way of defining structs with structType field and elements - val structTypeProto = proto.DataType.Struct - .newBuilder() - .addFields( - proto.DataType.StructField - .newBuilder() - .setName("a") - .setDataType(proto.DataType - .newBuilder() - .setInteger(proto.DataType.Integer.newBuilder()) - .build()) - .setNullable(true) - .build()) - .addFields( - proto.DataType.StructField - .newBuilder() - .setName("b") - .setDataType(proto.DataType - .newBuilder() - .setString(proto.DataType.String.newBuilder()) - .build()) - .setNullable(false) - .build()) - .build() - - val structProto = proto.Expression.Literal.Struct - .newBuilder() - .setStructType(proto.DataType.newBuilder().setStruct(structTypeProto).build()) - .addElements(LiteralValueProtoConverter.toLiteralProto(1)) - .addElements(LiteralValueProtoConverter.toLiteralProto("test")) - .build() + val structProto = LiteralValueProtoConverter.toLiteralProtoWithOptions( + (1, "test"), + Some( + StructType( + Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = false)))), + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + assert(!structProto.getStruct.hasDataTypeStruct) + assert(structProto.getStruct.getElementsList.size == 2) + val structTypeProto = structProto.getStruct.getStructType.getStruct + assert(structTypeProto.getFieldsList.size == 2) + assert(structTypeProto.getFieldsList.get(0).getName == "a") + assert(structTypeProto.getFieldsList.get(0).getDataType.hasInteger) + assert(structTypeProto.getFieldsList.get(1).getName == "b") + 
assert(structTypeProto.getFieldsList.get(1).getDataType.hasString) - val result = LiteralValueProtoConverter.toScalaStruct(structProto) - val resultType = LiteralValueProtoConverter.getProtoStructType(structProto) + val result = LiteralValueProtoConverter.toScalaStruct(structProto.getStruct) + val resultType = LiteralValueProtoConverter.getProtoStructType(structProto.getStruct) // Verify the result is a tuple with correct values assert(result.isInstanceOf[Product]) @@ -259,4 +248,70 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i assert(!structTypeProto.getFieldsList.get(1).getNullable) assert(!structTypeProto.getFieldsList.get(1).hasMetadata) } + + test("element type of array literal is set for an empty array") { + val literalProto = + toLiteralProto(Array[Int](), ArrayType(IntegerType)) + assert(literalProto.getArray.getDataType.hasElementType) + } + + test("element type of array literal is set for a non-empty array with non-inferable type") { + val literalProto = toLiteralProto(Array[String]("1", "2", "3"), ArrayType(StringType)) + assert(literalProto.getArray.getDataType.hasElementType) + } + + test("element type of array literal is not set for a non-empty array with inferable type") { + val literalProto = + toLiteralProto(Array(1, 2, 3), ArrayType(IntegerType)) + assert(!literalProto.getArray.getDataType.hasElementType) + } + + test("key and value type of map literal are set for an empty map") { + val literalProto = toLiteralProto(Map[Int, Int](), MapType(IntegerType, IntegerType)) + assert(literalProto.getMap.getDataType.hasKeyType) + assert(literalProto.getMap.getDataType.hasValueType) + } + + test("key type of map literal is set for a non-empty map with non-inferable key type") { + val literalProto = toLiteralProto( + Map[String, Int]("1" -> 1, "2" -> 2, "3" -> 3), + MapType(StringType, IntegerType)) + assert(literalProto.getMap.getDataType.hasKeyType) + assert(!literalProto.getMap.getDataType.hasValueType) + } + + test("value type of map literal is set for a non-empty map with non-inferable value type") { + val literalProto = toLiteralProto( + Map[Int, String](1 -> "1", 2 -> "2", 3 -> "3"), + MapType(IntegerType, StringType)) + assert(!literalProto.getMap.getDataType.hasKeyType) + assert(literalProto.getMap.getDataType.hasValueType) + } + + test("key and value type of map literal are not set for a non-empty map with inferable types") { + val literalProto = + toLiteralProto(Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType)) + assert(!literalProto.getMap.getDataType.hasKeyType) + assert(!literalProto.getMap.getDataType.hasValueType) + } + + test("an invalid array literal") { + val literalProto = proto.Expression.Literal + .newBuilder() + .setArray(proto.Expression.Literal.Array.newBuilder()) + .build() + intercept[InvalidPlanInput] { + LiteralValueProtoConverter.toScalaValue(literalProto) + } + } + + test("an invalid map literal") { + val literalProto = proto.Expression.Literal + .newBuilder() + .setMap(proto.Expression.Literal.Map.newBuilder()) + .build() + intercept[InvalidPlanInput] { + LiteralValueProtoConverter.toScalaValue(literalProto) + } + } }
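As a closing usage sketch mirroring the new tests above (same converter API; nothing beyond the test classpath is assumed): with useDeprecatedDataTypeFields = false, a map literal whose key/value types are inferable from its first entry is encoded without them, and the new MAP branch of toScalaValue still round-trips the value.

import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.ToLiteralProtoOptions
import org.apache.spark.sql.types.{IntegerType, MapType}

val literalProto = LiteralValueProtoConverter.toLiteralProtoWithOptions(
  Map(1 -> 2, 3 -> 4),
  Some(MapType(IntegerType, IntegerType)),
  ToLiteralProtoOptions(useDeprecatedDataTypeFields = false))

// Both types are inferable from the first key/value pair, so the builder
// leaves data_type.key_type and data_type.value_type unset to save space.
assert(!literalProto.getMap.getDataType.hasKeyType)
assert(!literalProto.getMap.getDataType.hasValueType)

// Decoding recovers the original Scala value.
assert(LiteralValueProtoConverter.toScalaValue(literalProto) == Map(1 -> 2, 3 -> 4))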