From a7427153acc8c4225efc00d49a88fd8314a78d31 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Sat, 23 Aug 2025 22:22:49 +0200 Subject: [PATCH] [SPARK-53354] Simplify LiteralValueProtoConverter.toCatalystStruct --- .../common/LiteralValueProtoConverter.scala | 145 +++++++----------- .../LiteralExpressionProtoConverter.scala | 6 +- ...LiteralExpressionProtoConverterSuite.scala | 5 +- 3 files changed, 63 insertions(+), 93 deletions(-) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index a4d8b0f2a02d..293ffe17bb4f 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -320,7 +320,7 @@ object LiteralValueProtoConverter { toCatalystArray(literal.getArray) case proto.Expression.Literal.LiteralTypeCase.STRUCT => - toCatalystStruct(literal.getStruct)._1 + toCatalystStruct(literal.getStruct) case other => throw new UnsupportedOperationException( @@ -328,9 +328,7 @@ object LiteralValueProtoConverter { } } - private def getConverter( - dataType: proto.DataType, - inferDataType: Boolean = false): proto.Expression.Literal => Any = { + private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { dataType.getKindCase match { case proto.DataType.KindCase.SHORT => v => v.getShort.toShort case proto.DataType.KindCase.INTEGER => v => v.getInteger @@ -354,20 +352,15 @@ object LiteralValueProtoConverter { case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray) case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap) case proto.DataType.KindCase.STRUCT => - if (inferDataType) { v => - val (struct, structType) = toCatalystStruct(v.getStruct, None) - LiteralValueWithDataType( - struct, - proto.DataType.newBuilder.setStruct(structType).build()) - } else { v => - toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1 - } + v => toCatalystStructInternal(v.getStruct, dataType.getStruct) case _ => throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)") } } - private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = { + private def getInferredDataType( + literal: proto.Expression.Literal, + recursive: Boolean = false): Option[proto.DataType] = { if (literal.hasNull) { return Some(literal.getNull) } @@ -399,8 +392,31 @@ object LiteralValueProtoConverter { case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build()) case proto.Expression.Literal.LiteralTypeCase.STRUCT => - // The type of the fields will be inferred from the literals of the fields in the struct. - builder.setStruct(literal.getStruct.getStructType.getStruct) + if (recursive) { + val structType = literal.getStruct.getDataTypeStruct + val structData = literal.getStruct.getElementsList.asScala + val structTypeBuilder = proto.DataType.Struct.newBuilder + for ((element, field) <- structData.zip(structType.getFieldsList.asScala)) { + if (field.hasDataType) { + structTypeBuilder.addFields(field) + } else { + getInferredDataType(element, recursive = true) match { + case Some(dataType) => + val fieldBuilder = structTypeBuilder.addFieldsBuilder() + fieldBuilder.setName(field.getName) + fieldBuilder.setDataType(dataType) + fieldBuilder.setNullable(field.getNullable) + if (field.hasMetadata) { + fieldBuilder.setMetadata(field.getMetadata) + } + case None => return None + } + } + } + builder.setStruct(structTypeBuilder.build()) + } else { + builder.setStruct(proto.DataType.Struct.newBuilder.build()) + } case _ => // Not all data types support inferring the data type from the literal at the moment. // e.g. the type of DayTimeInterval contains extra information like start_field and @@ -410,13 +426,6 @@ object LiteralValueProtoConverter { Some(builder.build()) } - private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = { - getInferredDataType(literal).getOrElse { - throw InvalidPlanInput( - s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}") - } - } - def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = { def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit tag: ClassTag[T]): Array[T] = { @@ -451,9 +460,9 @@ object LiteralValueProtoConverter { makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType)) } - def toCatalystStruct( + private def toCatalystStructInternal( struct: proto.Expression.Literal.Struct, - structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = { + structType: proto.DataType.Struct): Any = { def toTuple[A <: Object](data: Seq[A]): Product = { try { val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}") @@ -464,78 +473,36 @@ object LiteralValueProtoConverter { } } - if (struct.hasDataTypeStruct) { - // The new way to define and convert structs. - val (structData, structType) = if (structTypeOpt.isDefined) { - val structFields = structTypeOpt.get.getFieldsList.asScala - val structData = - struct.getElementsList.asScala.zip(structFields).map { case (element, structField) => - getConverter(structField.getDataType)(element) - } - (structData, structTypeOpt.get) - } else { - def protoStructField( - name: String, - dataType: proto.DataType, - nullable: Boolean, - metadata: Option[String]): proto.DataType.StructField = { - val builder = proto.DataType.StructField - .newBuilder() - .setName(name) - .setDataType(dataType) - .setNullable(nullable) - metadata.foreach(builder.setMetadata) - builder.build() - } - - val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala - - val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map { - case (element, dataTypeField) => - if (dataTypeField.hasDataType) { - (getConverter(dataTypeField.getDataType)(element), dataTypeField) - } else { - val outerDataType = getInferredDataTypeOrThrow(element) - val (value, dataType) = - getConverter(outerDataType, inferDataType = true)(element) match { - case LiteralValueWithDataType(value, dataType) => (value, dataType) - case value => (value, outerDataType) - } - ( - value, - protoStructField( - dataTypeField.getName, - dataType, - dataTypeField.getNullable, - if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None)) - } - } + val elements = struct.getElementsList.asScala + val dataTypes = structType.getFieldsList.asScala.map(_.getDataType) + val structData = elements + .zip(dataTypes) + .map { case (element, dataType) => + getConverter(dataType)(element) + } + .asInstanceOf[scala.collection.Seq[Object]] + .toSeq - val structType = proto.DataType.Struct - .newBuilder() - .addAllFields(structDataAndFields.map(_._2).asJava) - .build() + toTuple(structData) + } - (structDataAndFields.map(_._1), structType) + def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = { + if (struct.hasDataTypeStruct) { + val literal = proto.Expression.Literal.newBuilder().setStruct(struct).build() + getInferredDataType(literal, recursive = true) match { + case Some(dataType) => dataType.getStruct + case None => throw InvalidPlanInput("Cannot infer data type from this struct literal.") } - (toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType) } else if (struct.hasStructType) { - // For backward compatibility, we still support the old way to define and convert structs. - val elements = struct.getElementsList.asScala - val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType) - val structData = elements - .zip(dataTypes) - .map { case (element, dataType) => - getConverter(dataType)(element) - } - .asInstanceOf[scala.collection.Seq[Object]] - .toSeq - - (toTuple(structData), struct.getStructType.getStruct) + // For backward compatibility, we still support the old way to + // define and convert struct types. + struct.getStructType.getStruct } else { throw InvalidPlanInput("Data type information is missing in the struct literal.") } } - private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType) + def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = { + toCatalystStructInternal(struct, getProtoStructType(struct)) + } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala index 10f046a57da9..f4c56d461bd2 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala @@ -117,9 +117,11 @@ object LiteralExpressionProtoConverter { DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType))) case proto.Expression.Literal.LiteralTypeCase.STRUCT => - val (structData, structType) = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct) + val structData = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct) val dataType = DataTypeProtoConverter.toCatalystType( - proto.DataType.newBuilder.setStruct(structType).build()) + proto.DataType.newBuilder + .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct)) + .build()) val convert = CatalystTypeConverters.createToCatalystConverter(dataType) expressions.Literal(convert(structData), dataType) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index 559984e47cf8..71fcd2b39492 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -99,7 +99,8 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i .addElements(LiteralValueProtoConverter.toLiteralProto("test")) .build() - val (result, resultType) = LiteralValueProtoConverter.toCatalystStruct(structProto) + val result = LiteralValueProtoConverter.toCatalystStruct(structProto) + val resultType = LiteralValueProtoConverter.getProtoStructType(structProto) // Verify the result is a tuple with correct values assert(result.isInstanceOf[Product]) @@ -156,7 +157,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i assert(!structFields.get(1).getNullable) assert(!structFields.get(1).hasMetadata) - val (_, structTypeProto) = LiteralValueProtoConverter.toCatalystStruct(literalProto.getStruct) + val structTypeProto = LiteralValueProtoConverter.getProtoStructType(literalProto.getStruct) assert(structTypeProto.getFieldsList.get(0).getNullable) assert(structTypeProto.getFieldsList.get(0).hasMetadata) assert(structTypeProto.getFieldsList.get(0).getMetadata == """{"key":"value"}""")