Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import scala.collection.immutable.Seq

import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
import org.apache.spark.sql.types._

/**
 * Tests for building and collecting DataFrames that carry [[Geography]] values through a
 * remote Spark session: SRID mismatch errors, the mixed ("ANY") SRID schema, the different
 * `createDataFrame` entry points, and a SQL-produced geography value.
 */
class GeographyConnectDataFrameSuite extends QueryTest with RemoteSparkSession {

  /** Turns a hex string (two characters per byte) into the corresponding byte array. */
  private def hexToBytes(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

  // WKB payloads shared by the tests below.
  private val point1: Array[Byte] = hexToBytes("010100000000000000000031400000000000001C40")
  private val point2: Array[Byte] = hexToBytes("010100000000000000000035400000000000001E40")

  test("decode geography value: SRID schema does not match input SRID data schema") {
    val sridZeroGeography = Geography.fromWKB(point1, 0)
    val tupleData = Seq((sridZeroGeography, 1))

    // Path 1: createDataFrame over a Seq of tuples.
    val fromCreateDataFrame = intercept[SparkRuntimeException] {
      spark.createDataFrame(tupleData).collect()
    }
    checkError(
      exception = fromCreateDataFrame,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))

    // Path 2: implicit Seq -> DataFrame conversion.
    import testImplicits._
    val fromToDF = intercept[SparkRuntimeException] {
      Seq(sridZeroGeography).toDF().collect()
    }
    checkError(
      exception = fromToDF,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
  }

  test("decode geography value: mixed SRID schema is provided") {
    val anySridSchema =
      StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))

    // Rows whose SRID is 4326 are returned unchanged under the "ANY" schema.
    val expectedRows =
      Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
    val validRows = java.util.Arrays
      .asList(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
    checkAnswer(spark.createDataFrame(validRows, anySridSchema), expectedRows)

    // A row carrying SRID 1 is expected to be rejected even under the mixed schema.
    val rowsWithBadSrid = java.util.Arrays
      .asList(Row(Geography.fromWKB(point1, 1)), Row(Geography.fromWKB(point2, 4326)))
    val invalidSridError = intercept[SparkIllegalArgumentException] {
      spark.createDataFrame(rowsWithBadSrid, anySridSchema).collect()
    }
    checkError(
      exception = invalidSridError,
      condition = "ST_INVALID_SRID_VALUE",
      parameters = Map("srid" -> "1"))
  }

  test("createDataFrame APIs with Geography.fromWKB") {
    val geography1 = Geography.fromWKB(point1, 4326)
    val geography2 = Geography.fromWKB(point2)

    // createDataFrame over a Seq of (geography, int) tuples, including a null geography.
    val tuples = Seq((geography1, 1), (geography2, 2), (null, 3))
    checkAnswer(
      spark.createDataFrame(tuples),
      Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))

    // createDataFrame over a Java list of Rows with an explicit schema.
    val explicitSchema =
      StructType(Seq(StructField("geography", GeographyType(4326), nullable = true)))
    val rowList = java.util.Arrays.asList(Row(geography1), Row(geography2), Row(null))
    checkAnswer(
      spark.createDataFrame(rowList, explicitSchema),
      Seq(Row(geography1), Row(geography2), Row(null)))

    // Implicit Seq -> DataFrame conversion.
    import testImplicits._
    val viaImplicits = Seq(geography1, geography2, null).toDF()
    checkAnswer(viaImplicits, Seq(Row(geography1), Row(geography2), Row(null)))
  }

  test("encode geography type") {
    // POINT (17 7)
    val wkb = "010100000000000000000031400000000000001C40"
    val sqlResult = spark.sql(s"SELECT ST_GeogFromWKB(X'$wkb')")
    val expectedGeog = Geography.fromWKB(hexToBytes(wkb), 4326)
    checkAnswer(sqlResult, Seq(Row(expectedGeog)))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import scala.collection.immutable.Seq

import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
import org.apache.spark.sql.types._

/**
 * Tests for building and collecting DataFrames that carry [[Geometry]] values through a
 * remote Spark session: SRID mismatch errors, the mixed ("ANY") SRID schema, the different
 * `createDataFrame` entry points, and a SQL-produced geometry value.
 */
class GeometryConnectDataFrameSuite extends QueryTest with RemoteSparkSession {

  /** Turns a hex string (two characters per byte) into the corresponding byte array. */
  private def hexToBytes(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

  // WKB payloads shared by the tests below.
  private val point1: Array[Byte] = hexToBytes("010100000000000000000031400000000000001C40")
  private val point2: Array[Byte] = hexToBytes("010100000000000000000035400000000000001E40")

  test("decode geometry value: SRID schema does not match input SRID data schema") {
    val srid4326Geometry = Geometry.fromWKB(point1, 4326)
    val tupleData = Seq((srid4326Geometry, 1))

    // Path 1: createDataFrame over a Seq of tuples.
    val fromCreateDataFrame = intercept[SparkRuntimeException] {
      spark.createDataFrame(tupleData).collect()
    }
    checkError(
      exception = fromCreateDataFrame,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))

    // Path 2: implicit Seq -> DataFrame conversion.
    import testImplicits._
    val fromToDF = intercept[SparkRuntimeException] {
      Seq(srid4326Geometry).toDF().collect()
    }
    checkError(
      exception = fromToDF,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
  }

  test("decode geometry value: mixed SRID schema is provided") {
    val anySridSchema =
      StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))

    // Rows with different SRIDs (0 and 4326) are returned unchanged under the "ANY" schema.
    val expectedRows =
      Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
    val validRows = java.util.Arrays
      .asList(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
    checkAnswer(spark.createDataFrame(validRows, anySridSchema), expectedRows)

    // A row carrying SRID 1 is expected to be rejected even under the mixed schema.
    val rowsWithBadSrid = java.util.Arrays
      .asList(Row(Geometry.fromWKB(point1, 1)), Row(Geometry.fromWKB(point2, 4326)))
    val invalidSridError = intercept[SparkIllegalArgumentException] {
      spark.createDataFrame(rowsWithBadSrid, anySridSchema).collect()
    }
    checkError(
      exception = invalidSridError,
      condition = "ST_INVALID_SRID_VALUE",
      parameters = Map("srid" -> "1"))
  }

  test("createDataFrame APIs with Geometry.fromWKB") {
    val geometry1 = Geometry.fromWKB(point1, 0)
    val geometry2 = Geometry.fromWKB(point2, 0)

    // createDataFrame over a Seq of (geometry, int) tuples, including a null geometry.
    val tuples = Seq((geometry1, 1), (geometry2, 2), (null, 3))
    checkAnswer(
      spark.createDataFrame(tuples),
      Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))

    // createDataFrame over a Java list of Rows with an explicit geometry(4326) schema;
    // this path uses SRID 4326 values so that they agree with the declared column type.
    val geometry3 = Geometry.fromWKB(point1, 4326)
    val geometry4 = Geometry.fromWKB(point2, 4326)
    val explicitSchema =
      StructType(Seq(StructField("geometry", GeometryType(4326), nullable = true)))
    val rowList = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
    checkAnswer(
      spark.createDataFrame(rowList, explicitSchema),
      Seq(Row(geometry3), Row(geometry4), Row(null)))

    // Implicit Seq -> DataFrame conversion.
    import testImplicits._
    val viaImplicits = Seq(geometry1, geometry2, null).toDF()
    checkAnswer(viaImplicits, Seq(Row(geometry1), Row(geometry2), Row(null)))
  }

  test("encode geometry type") {
    // POINT (17 7)
    val wkb = "010100000000000000000031400000000000001C40"
    val sqlResult = spark.sql(s"SELECT ST_GeomFromWKB(X'$wkb')")
    val expectedGeom = Geometry.fromWKB(hexToBytes(wkb), 0)
    checkAnswer(sqlResult, Seq(Row(expectedGeom)))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
import org.apache.spark.sql.connect.test.ConnectFunSuite
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.{MaybeNull, SparkStringUtils}

Expand Down Expand Up @@ -263,6 +263,102 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
assert(inspector.numBatches == 1)
}

// Round-trips Geography values through the row encoder, first as a single top-level column
// and then nested inside a struct, an array and a map.
test("geography round trip") {
  // WKB payloads, decoded from hex two characters per byte.
  val point1 = "010100000000000000000031400000000000001C40"
    .grouped(2)
    .map(Integer.parseInt(_, 16).toByte)
    .toArray
  val point2 = "010100000000000000000035400000000000001E40"
    .grouped(2)
    .map(Integer.parseInt(_, 16).toByte)
    .toArray

  val geographyEncoder = toRowEncoder(new StructType().add("g", "geography(4326)"))
  roundTripAndCheckIdentical(geographyEncoder) { () =>
    val maybeNull = MaybeNull(7)
    Iterator.tabulate(101)(_ => Row(maybeNull(Geography.fromWKB(point1, 4326))))
  }

  val nestedGeographyEncoder = toRowEncoder(
    new StructType()
      .add(
        "s",
        new StructType()
          .add("i1", "int")
          .add("g0", "geography(4326)")
          .add("i2", "int")
          .add("g4326", "geography(4326)"))
      .add("a", "array<geography(4326)>")
      .add("m", "map<string, geography(ANY)>"))

  roundTripAndCheckIdentical(nestedGeographyEncoder) { () =>
    val maybeNull5 = MaybeNull(5)
    val maybeNull7 = MaybeNull(7)
    val maybeNull11 = MaybeNull(11)
    val maybeNull13 = MaybeNull(13)
    val maybeNull17 = MaybeNull(17)
    Iterator
      .tabulate(100)(i =>
        Row(
          maybeNull5(
            Row(
              i,
              maybeNull7(Geography.fromWKB(point1)),
              i + 1,
              maybeNull11(Geography.fromWKB(point2, 4326)))),
          // The array column is declared as array<geography(4326)>, so its elements are
          // created with SRID 4326 to agree with the declared element type.
          maybeNull7((0 until 10).map(_ => Geography.fromWKB(point2, 4326))),
          maybeNull13(Map((i.toString, maybeNull17(Geography.fromWKB(point1, 4326)))))))
  }
}

// Round-trips Geometry values through the row encoder, first as a single top-level column
// and then nested inside a struct, an array and a map.
test("geometry round trip") {
  // Decodes a hex string (two characters per byte) into a byte array.
  def fromHex(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

  val wkbA = fromHex("010100000000000000000031400000000000001C40")
  val wkbB = fromHex("010100000000000000000035400000000000001E40")

  // Case 1: one top-level geometry(0) column.
  val flatSchema = new StructType().add("g", "geometry(0)")
  roundTripAndCheckIdentical(toRowEncoder(flatSchema)) { () =>
    val wrap7 = MaybeNull(7)
    Iterator.tabulate(101)(_ => Row(wrap7(Geometry.fromWKB(wkbA, 0))))
  }

  // Case 2: geometry values nested in a struct, an array and a map.
  val innerStruct = new StructType()
    .add("i1", "int")
    .add("g0", "geometry(0)")
    .add("i2", "int")
    .add("g4326", "geometry(4326)")
  val nestedSchema = new StructType()
    .add("s", innerStruct)
    .add("a", "array<geometry(0)>")
    .add("m", "map<string, geometry(ANY)>")

  roundTripAndCheckIdentical(toRowEncoder(nestedSchema)) { () =>
    val wrap5 = MaybeNull(5)
    val wrap7 = MaybeNull(7)
    val wrap11 = MaybeNull(11)
    val wrap13 = MaybeNull(13)
    val wrap17 = MaybeNull(17)
    Iterator.tabulate(100) { i =>
      // wrap7 is applied to both the struct's first geometry and the array; the columns
      // are built in struct, array, map order so the sequence of wrapper calls is the
      // same as when they are written as inline arguments of a single Row(...).
      val structColumn = wrap5(
        Row(i, wrap7(Geometry.fromWKB(wkbA, 0)), i + 1, wrap11(Geometry.fromWKB(wkbB, 4326))))
      val arrayColumn = wrap7((0 until 10).map(_ => Geometry.fromWKB(wkbB, 0)))
      val mapColumn = wrap13(Map((i.toString, wrap17(Geometry.fromWKB(wkbA, 4326)))))
      Row(structColumn, arrayColumn, mapColumn)
    }
  }
}

test("variant round trip") {
val variantEncoder = toRowEncoder(new StructType().add("v", "variant"))
roundTripAndCheckIdentical(variantEncoder) { () =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,14 @@ object ArrowDeserializers {
}
}

// A geometry encoder paired with struct vectors: the deserializer is built by a fresh
// GeometryArrowSerDe from the struct, its child vectors and the time zone id.
case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
  (new GeometryArrowSerDe).createDeserializer(struct, vectors, timeZoneId)

// A geography encoder paired with struct vectors: the deserializer is built by a fresh
// GeographyArrowSerDe from the struct, its child vectors and the time zone id.
case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
  (new GeographyArrowSerDe).createDeserializer(struct, vectors, timeZoneId)

case (VariantEncoder, StructVectors(struct, vectors)) =>
assert(vectors.exists(_.getName == "value"))
assert(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ private[arrow] object ArrowEncoderUtils {
def unsupportedCollectionType(cls: Class[_]): Nothing = {
throw new RuntimeException(s"Unsupported collection type: $cls")
}

/**
 * Asserts that the given Arrow vectors contain the expected child vectors and field metadata.
 *
 * Two checks are performed:
 *  1. every name in `expectedVectors` is the name of some vector in `vectors`;
 *  2. `expectedVectors` and `expectedMetadata` are paired positionally, and for every pair
 *     the vector with that name must carry the metadata key with exactly that value.
 *
 * The pairing uses `zip`, so when the two sequences differ in length only the first
 * `min(expectedVectors.size, expectedMetadata.size)` pairs take part in the metadata check.
 *
 * @param vectors vectors to inspect
 * @param expectedVectors names that must be present among `vectors`
 * @param expectedMetadata `(key, value)` metadata entries, matched to `expectedVectors` by
 *                         position
 * @throws AssertionError if a vector or metadata entry is missing; the message names the
 *                        offending vector (and metadata entry)
 */
def assertMetadataPresent(
    vectors: Seq[FieldVector],
    expectedVectors: Seq[String],
    expectedMetadata: Seq[(String, String)]): Unit = {
  expectedVectors.foreach { vectorName =>
    assert(
      vectors.exists(_.getName == vectorName),
      s"Expected vector '$vectorName' is missing")
  }

  expectedVectors.zip(expectedMetadata).foreach { case (vectorName, (key, value)) =>
    assert(
      vectors.exists(field =>
        field.getName == vectorName && field.getField.getMetadata
          .containsKey(key) && field.getField.getMetadata.get(key) == value),
      s"Vector '$vectorName' is missing metadata entry '$key' -> '$value'")
  }
}
}

private[arrow] object StructVectors {
Expand Down
Loading