diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala
new file mode 100644
index 000000000000..2016a84ac5a3
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.types._
+
+class GeographyConnectDataFrameSuite extends QueryTest with RemoteSparkSession {
+
+  private val point1: Array[Byte] = "010100000000000000000031400000000000001C40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+  private val point2: Array[Byte] = "010100000000000000000035400000000000001E40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+
+  test("decode geography value: SRID schema does not match input SRID data schema") {
+    val geography = Geography.fromWKB(point1, 0)
+
+    val seq = Seq((geography, 1))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(seq).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
+
+    import testImplicits._
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        Seq(geography).toDF().collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
+  }
+
+  test("decode geography value: mixed SRID schema is provided") {
+    val schema = StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))
+    val expectedResult =
+      Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+
+    val javaList = java.util.Arrays
+      .asList(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+    val resultJavaListDF = spark.createDataFrame(javaList, schema)
+    checkAnswer(resultJavaListDF, expectedResult)
+
+    // Test that an unsupported SRID with a mixed schema throws an error.
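+    // SRID 1 is assumed to be outside the supported set, so even the permissive
+    // GEOGRAPHY("ANY") type is expected to reject it with ST_INVALID_SRID_VALUE.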
+    val invalidData =
+      java.util.Arrays
+        .asList(Row(Geography.fromWKB(point1, 1)), Row(Geography.fromWKB(point2, 4326)))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        spark.createDataFrame(invalidData, schema).collect()
+      },
+      condition = "ST_INVALID_SRID_VALUE",
+      parameters = Map("srid" -> "1"))
+  }
+
+  test("createDataFrame APIs with Geography.fromWKB") {
+    val geography1 = Geography.fromWKB(point1, 4326)
+    val geography2 = Geography.fromWKB(point2)
+
+    val seq = Seq((geography1, 1), (geography2, 2), (null, 3))
+    val dfFromSeq = spark.createDataFrame(seq)
+    checkAnswer(dfFromSeq, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+    val schema = StructType(Seq(StructField("geography", GeographyType(4326), nullable = true)))
+
+    val javaList = java.util.Arrays.asList(Row(geography1), Row(geography2), Row(null))
+    val dfFromJavaList = spark.createDataFrame(javaList, schema)
+    checkAnswer(dfFromJavaList, Seq(Row(geography1), Row(geography2), Row(null)))
+
+    import testImplicits._
+    val implicitDf = Seq(geography1, geography2, null).toDF()
+    checkAnswer(implicitDf, Seq(Row(geography1), Row(geography2), Row(null)))
+  }
+
+  test("encode geography type") {
+    // POINT (17 7)
+    val wkb = "010100000000000000000031400000000000001C40"
+    val df = spark.sql(s"SELECT ST_GeogFromWKB(X'$wkb')")
+    val point = wkb.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val expectedGeog = Geography.fromWKB(point, 4326)
+    checkAnswer(df, Seq(Row(expectedGeog)))
+  }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala
new file mode 100644
index 000000000000..1450ac54184b
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.types._
+
+class GeometryConnectDataFrameSuite extends QueryTest with RemoteSparkSession {
+
+  private val point1: Array[Byte] = "010100000000000000000031400000000000001C40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+  private val point2: Array[Byte] = "010100000000000000000035400000000000001E40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+
+  test("decode geometry value: SRID schema does not match input SRID data schema") {
+    val geometry = Geometry.fromWKB(point1, 4326)
+
+    val seq = Seq((geometry, 1))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(seq).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
+
+    import testImplicits._
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        Seq(geometry).toDF().collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
+  }
+
+  test("decode geometry value: mixed SRID schema is provided") {
+    val schema = StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))
+    val expectedResult =
+      Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+
+    val javaList = java.util.Arrays
+      .asList(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+    val resultJavaListDF = spark.createDataFrame(javaList, schema)
+    checkAnswer(resultJavaListDF, expectedResult)
+
+    // Test that an unsupported SRID with a mixed schema throws an error.
+    val invalidData =
+      java.util.Arrays
+        .asList(Row(Geometry.fromWKB(point1, 1)), Row(Geometry.fromWKB(point2, 4326)))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        spark.createDataFrame(invalidData, schema).collect()
+      },
+      condition = "ST_INVALID_SRID_VALUE",
+      parameters = Map("srid" -> "1"))
+  }
+
+  test("createDataFrame APIs with Geometry.fromWKB") {
+    val geometry1 = Geometry.fromWKB(point1, 0)
+    val geometry2 = Geometry.fromWKB(point2, 0)
+
+    // 1. Test createDataFrame with a Seq of Geometry objects.
+    val seq = Seq((geometry1, 1), (geometry2, 2), (null, 3))
+    val dfFromSeq = spark.createDataFrame(seq)
+    checkAnswer(dfFromSeq, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+    // 2. Prepare geometries and a schema for the Row-based createDataFrame API.
+    val geometry3 = Geometry.fromWKB(point1, 4326)
+    val geometry4 = Geometry.fromWKB(point2, 4326)
+    val schema = StructType(Seq(StructField("geometry", GeometryType(4326), nullable = true)))
+
+    // 3. Test createDataFrame with a Java List of Rows and a StructType schema.
+    val javaList = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
+    val dfFromJavaList = spark.createDataFrame(javaList, schema)
+    checkAnswer(dfFromJavaList, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+    // 4. Test the implicit conversion from Seq to DataFrame.
+    import testImplicits._
+    val implicitDf = Seq(geometry1, geometry2, null).toDF()
+    checkAnswer(implicitDf, Seq(Row(geometry1), Row(geometry2), Row(null)))
+  }
+
+  test("encode geometry type") {
+    // POINT (17 7)
+    val wkb = "010100000000000000000031400000000000001C40"
+    val df = spark.sql(s"SELECT ST_GeomFromWKB(X'$wkb')")
+    val point = wkb.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val expectedGeom = Geometry.fromWKB(point, 0)
+    checkAnswer(df, Seq(Row(expectedGeom)))
+  }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index bc840df5c3fa..d24369ff5fc7 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
 import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.connect.test.ConnectFunSuite
-import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
 import org.apache.spark.unsafe.types.VariantVal
 import org.apache.spark.util.{MaybeNull, SparkStringUtils}
 
@@ -263,6 +263,102 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
     assert(inspector.numBatches == 1)
   }
 
+  test("geography round trip") {
+    val point1 = "010100000000000000000031400000000000001C40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+    val point2 = "010100000000000000000035400000000000001E40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+
+    val geographyEncoder = toRowEncoder(new StructType().add("g", "geography(4326)"))
+    roundTripAndCheckIdentical(geographyEncoder) { () =>
+      val maybeNull = MaybeNull(7)
+      Iterator.tabulate(101)(i => Row(maybeNull(Geography.fromWKB(point1, 4326))))
+    }
+
+    val nestedGeographyEncoder = toRowEncoder(
+      new StructType()
+        .add(
+          "s",
+          new StructType()
+            .add("i1", "int")
+            .add("g0", "geography(4326)")
+            .add("i2", "int")
+            .add("g4326", "geography(4326)"))
+        .add("a", "array<geography(0)>")
+        .add("m", "map<string, geography(4326)>"))
+
+    roundTripAndCheckIdentical(nestedGeographyEncoder) { () =>
+      val maybeNull5 = MaybeNull(5)
+      val maybeNull7 = MaybeNull(7)
+      val maybeNull11 = MaybeNull(11)
+      val maybeNull13 = MaybeNull(13)
+      val maybeNull17 = MaybeNull(17)
+      Iterator
+        .tabulate(100)(i =>
+          Row(
+            maybeNull5(
+              Row(
+                i,
+                maybeNull7(Geography.fromWKB(point1)),
+                i + 1,
+                maybeNull11(Geography.fromWKB(point2, 4326)))),
+            maybeNull7((0 until 10).map(j => Geography.fromWKB(point2, 0))),
+            maybeNull13(Map((i.toString, maybeNull17(Geography.fromWKB(point1, 4326)))))))
+    }
+  }
+
+  test("geometry round trip") {
+    val point1 = "010100000000000000000031400000000000001C40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+    val point2 = "010100000000000000000035400000000000001E40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
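+    // Decoded as little-endian WKB doubles, point1 is POINT (17 7)
+    // and point2 is POINT (21 7.5).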
+
+    val geometryEncoder = toRowEncoder(new StructType().add("g", "geometry(0)"))
+    roundTripAndCheckIdentical(geometryEncoder) { () =>
+      val maybeNull = MaybeNull(7)
+      Iterator.tabulate(101)(i => Row(maybeNull(Geometry.fromWKB(point1, 0))))
+    }
+
+    val nestedGeometryEncoder = toRowEncoder(
+      new StructType()
+        .add(
+          "s",
+          new StructType()
+            .add("i1", "int")
+            .add("g0", "geometry(0)")
+            .add("i2", "int")
+            .add("g4326", "geometry(4326)"))
+        .add("a", "array<geometry(0)>")
+        .add("m", "map<string, geometry(4326)>"))
+
+    roundTripAndCheckIdentical(nestedGeometryEncoder) { () =>
+      val maybeNull5 = MaybeNull(5)
+      val maybeNull7 = MaybeNull(7)
+      val maybeNull11 = MaybeNull(11)
+      val maybeNull13 = MaybeNull(13)
+      val maybeNull17 = MaybeNull(17)
+      Iterator
+        .tabulate(100)(i =>
+          Row(
+            maybeNull5(
+              Row(
+                i,
+                maybeNull7(Geometry.fromWKB(point1, 0)),
+                i + 1,
+                maybeNull11(Geometry.fromWKB(point2, 4326)))),
+            maybeNull7((0 until 10).map(j => Geometry.fromWKB(point2, 0))),
+            maybeNull13(Map((i.toString, maybeNull17(Geometry.fromWKB(point1, 4326)))))))
+    }
+  }
+
   test("variant round trip") {
     val variantEncoder = toRowEncoder(new StructType().add("v", "variant"))
     roundTripAndCheckIdentical(variantEncoder) { () =>
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 7597a0ceeb8c..8d5811dda8f3 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -341,6 +341,14 @@ object ArrowDeserializers {
       }
     }
 
+    case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+      val gdser = new GeometryArrowSerDe
+      gdser.createDeserializer(struct, vectors, timeZoneId)
+
+    case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+      val gdser = new GeographyArrowSerDe
+      gdser.createDeserializer(struct, vectors, timeZoneId)
+
     case (VariantEncoder, StructVectors(struct, vectors)) =>
       assert(vectors.exists(_.getName == "value"))
       assert(
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index 5b1539e39f4f..2430c2bbc86f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -41,6 +41,22 @@ private[arrow] object ArrowEncoderUtils {
   def unsupportedCollectionType(cls: Class[_]): Nothing = {
     throw new RuntimeException(s"Unsupported collection type: $cls")
   }
+
+  def assertMetadataPresent(
+      vectors: Seq[FieldVector],
+      expectedVectors: Seq[String],
+      expectedMetadata: Seq[(String, String)]): Unit = {
+    expectedVectors.foreach { vectorName =>
+      assert(vectors.exists(_.getName == vectorName))
+    }
+
+    expectedVectors.zip(expectedMetadata).foreach { case (vectorName, (key, value)) =>
+      assert(
+        vectors.exists(field =>
+          field.getName == vectorName && field.getField.getMetadata
+            .containsKey(key) && field.getField.getMetadata.get(key) == value))
+    }
+  }
 }
 
 private[arrow] object StructVectors {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 4acb11f014d1..73c9a991ab6a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -487,6 +487,14 @@ object ArrowSerializer {
             extractor = (v: Any) => v.asInstanceOf[VariantVal].getMetadata,
             serializerFor(BinaryEncoder, struct.getChild("metadata")))))
 
+    case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+      val gser = new GeographyArrowSerDe
+      gser.createSerializer(struct, vectors)
+
+    case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+      val gser = new GeometryArrowSerDe
+      gser.createSerializer(struct, vectors)
+
     case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
       structSerializerFor(fields, struct, vectors) { (field, _) =>
         val getter = methodLookup.findVirtual(
@@ -585,12 +593,14 @@ object ArrowSerializer {
     }
   }
 
-  private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) {
+  private[arrow] class StructFieldSerializer(
+      val extractor: Any => Any,
+      val serializer: Serializer) {
     def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value))
     def writeNull(index: Int): Unit = serializer.write(index, null)
   }
 
-  private class StructSerializer(
+  private[arrow] class StructSerializer(
       struct: StructVector,
       fieldSerializers: Seq[StructFieldSerializer])
       extends Serializer {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
new file mode 100644
index 000000000000..443523ef02cd
--- /dev/null
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client.arrow
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.complex.StructVector
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, PrimitiveIntEncoder}
+import org.apache.spark.sql.errors.CompilationErrors
+import org.apache.spark.sql.types.{Geography, Geometry}
+
+abstract class GeospatialArrowSerDe[T](typeName: String) {
+
+  def createDeserializer(
+      struct: StructVector,
+      vectors: Seq[FieldVector],
+      timeZoneId: String): ArrowDeserializers.StructFieldSerializer[T] = {
+    assertMetadataPresent(vectors)
+    val wkbDecoder = ArrowDeserializers.deserializerFor(
+      BinaryEncoder,
+      vectors
+        .find(_.getName == "wkb")
+        .getOrElse(throw CompilationErrors.columnNotFoundError("wkb")),
+      timeZoneId)
+    val sridDecoder = ArrowDeserializers.deserializerFor(
+      PrimitiveIntEncoder,
+      vectors
+        .find(_.getName == "srid")
+        .getOrElse(throw CompilationErrors.columnNotFoundError("srid")),
+      timeZoneId)
+    new ArrowDeserializers.StructFieldSerializer[T](struct) {
+      override def value(i: Int): T = createInstance(wkbDecoder.get(i), sridDecoder.get(i))
+    }
+  }
+
+  def createSerializer(
+      struct: StructVector,
+      vectors: Seq[FieldVector]): ArrowSerializer.StructSerializer = {
+    assertMetadataPresent(vectors)
+    new ArrowSerializer.StructSerializer(
+      struct,
+      Seq(
+        new ArrowSerializer.StructFieldSerializer(
+          extractor = (v: Any) => extractSrid(v),
+          ArrowSerializer.serializerFor(PrimitiveIntEncoder, struct.getChild("srid"))),
+        new ArrowSerializer.StructFieldSerializer(
+          extractor = (v: Any) => extractBytes(v),
+          ArrowSerializer.serializerFor(BinaryEncoder, struct.getChild("wkb")))))
+  }
+
+  private def assertMetadataPresent(vectors: Seq[FieldVector]): Unit = {
+    assert(vectors.exists(_.getName == "srid"))
+    assert(
+      vectors.exists(field =>
+        field.getName == "wkb" && field.getField.getMetadata
+          .containsKey(typeName) && field.getField.getMetadata.get(typeName) == "true"))
+  }
+
+  protected def createInstance(wkb: Any, srid: Any): T
+  protected def extractSrid(value: Any): Int
+  protected def extractBytes(value: Any): Array[Byte]
+}
+
+// Geography-specific implementation
+class GeographyArrowSerDe extends GeospatialArrowSerDe[Geography]("geography") {
+  override protected def createInstance(wkb: Any, srid: Any): Geography =
+    Geography.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+  override protected def extractSrid(value: Any): Int =
+    value.asInstanceOf[Geography].getSrid
+
+  override protected def extractBytes(value: Any): Array[Byte] =
+    value.asInstanceOf[Geography].getBytes
+}
+
+// Geometry-specific implementation
+class GeometryArrowSerDe extends GeospatialArrowSerDe[Geometry]("geometry") {
+  override protected def createInstance(wkb: Any, srid: Any): Geometry =
+    Geometry.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+  override protected def extractSrid(value: Any): Int =
+    value.asInstanceOf[Geometry].getSrid
+
+  override protected def extractBytes(value: Any): Array[Byte] =
+    value.asInstanceOf[Geometry].getBytes
+}
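For orientation, a minimal sketch (not part of the patch) of the two-child Arrow struct layout the SerDe above reads and writes: an "srid" int vector plus a "wkb" binary vector whose field metadata carries the type-name marker that assertMetadataPresent checks. It assumes the Arrow Java pojo API; the helper object and method names are hypothetical.

import scala.jdk.CollectionConverters._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

object GeospatialFieldSketch {
  // Builds child fields matching what GeospatialArrowSerDe expects: a
  // non-nullable 32-bit signed "srid" vector, and a nullable "wkb" binary
  // vector tagged with typeName -> "true" ("geography" or "geometry").
  def children(typeName: String): Seq[Field] = {
    val srid =
      new Field("srid", FieldType.notNullable(new ArrowType.Int(32, true)), null)
    val wkb = new Field(
      "wkb",
      new FieldType(true, ArrowType.Binary.INSTANCE, null, Map(typeName -> "true").asJava),
      null)
    Seq(srid, wkb)
  }
}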
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 419cc8e082af..ac69f084c307 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -71,6 +71,21 @@ object DataTypeProtoConverter {
       case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap)
       case proto.DataType.KindCase.VARIANT => VariantType
 
+      case proto.DataType.KindCase.GEOMETRY =>
+        val srid = t.getGeometry.getSrid
+        if (srid == GeometryType.MIXED_SRID) {
+          GeometryType("ANY")
+        } else {
+          GeometryType(srid)
+        }
+      case proto.DataType.KindCase.GEOGRAPHY =>
+        val srid = t.getGeography.getSrid
+        if (srid == GeographyType.MIXED_SRID) {
+          GeographyType("ANY")
+        } else {
+          GeographyType(srid)
+        }
+
       case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt)
 
       case _ =>
@@ -307,6 +322,26 @@ object DataTypeProtoConverter {
           .build())
         .build()
 
+    case g: GeographyType =>
+      proto.DataType
+        .newBuilder()
+        .setGeography(
+          proto.DataType.Geography
+            .newBuilder()
+            .setSrid(g.srid)
+            .build())
+        .build()
+
+    case g: GeometryType =>
+      proto.DataType
+        .newBuilder()
+        .setGeometry(
+          proto.DataType.Geometry
+            .newBuilder()
+            .setSrid(g.srid)
+            .build())
+        .build()
+
     case VariantType => ProtoDataTypes.VariantType
 
     case pyudt: PythonUserDefinedType =>
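Taken together, the converter changes round-trip SRIDs through the proto: a concrete SRID is carried as-is, while the mixed form travels as the MIXED_SRID sentinel and is decoded back to "ANY". A minimal sketch of the intended behavior, assuming the object's existing toConnectProtoType/toCatalystType entry points and that GeometryType("ANY") compares equal to the MIXED_SRID-backed instance:

import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.GeometryType

object GeometryProtoRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // A concrete SRID is carried in the proto unchanged and decoded back as-is.
    val concrete = DataTypeProtoConverter.toConnectProtoType(GeometryType(4326))
    assert(DataTypeProtoConverter.toCatalystType(concrete) == GeometryType(4326))

    // The mixed form is encoded via the MIXED_SRID sentinel and decoded to "ANY".
    val mixed = DataTypeProtoConverter.toConnectProtoType(GeometryType("ANY"))
    assert(DataTypeProtoConverter.toCatalystType(mixed) == GeometryType("ANY"))
  }
}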