Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,8 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_retry",
"pyspark.sql.tests.connect.test_connect_session",
"pyspark.sql.tests.connect.test_connect_stat",
"pyspark.sql.tests.connect.test_parity_geographytype",
"pyspark.sql.tests.connect.test_parity_geometrytype",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,16 @@
"<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
]
},
"MALFORMED_GEOGRAPHY": {
"message": [
"Geography binary is malformed. Please check the data source is valid."
]
},
"MALFORMED_GEOMETRY": {
"message": [
"Geometry binary is malformed. Please check the data source is valid."
]
},
"MALFORMED_VARIANT": {
"message": [
"Variant binary is malformed. Please check the data source is valid."
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/sql/connect/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
NullType,
NumericType,
VariantType,
GeographyType,
GeometryType,
UserDefinedType,
)
from pyspark.errors import PySparkAssertionError, PySparkValueError
Expand Down Expand Up @@ -191,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret.array.contains_null = data_type.containsNull
elif isinstance(data_type, VariantType):
ret.variant.CopyFrom(pb2.DataType.Variant())
elif isinstance(data_type, GeometryType):
ret.geometry.srid = data_type.srid
elif isinstance(data_type, GeographyType):
ret.geography.srid = data_type.srid
elif isinstance(data_type, UserDefinedType):
json_value = data_type.jsonValue()
ret.udt.type = "udt"
Expand Down Expand Up @@ -303,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
)
elif schema.HasField("variant"):
return VariantType()
elif schema.HasField("geometry"):
srid = schema.geometry.srid
if srid == GeometryType.MIXED_SRID:
return GeometryType("ANY")
else:
return GeometryType(srid)
elif schema.HasField("geography"):
srid = schema.geography.srid
if srid == GeographyType.MIXED_SRID:
return GeographyType("ANY")
else:
return GeographyType(srid)
elif schema.HasField("udt"):
assert schema.udt.type == "udt"
json_value = {}
Expand Down
74 changes: 74 additions & 0 deletions python/pyspark/sql/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
BinaryType,
DataType,
DecimalType,
GeographyType,
Geography,
GeometryType,
Geometry,
MapType,
NullType,
Row,
Expand Down Expand Up @@ -89,6 +93,10 @@ def _need_converter(
return True
elif isinstance(dataType, VariantType):
return True
elif isinstance(dataType, GeometryType):
return True
elif isinstance(dataType, GeographyType):
return True
else:
return False

Expand Down Expand Up @@ -392,6 +400,34 @@ def convert_variant(value: Any) -> Any:

return convert_variant

elif isinstance(dataType, GeographyType):

def convert_geography(value: Any) -> Any:
    """Convert a ``Geography`` value via ``dataType.toInternal``.

    ``None`` is passed through when the enclosing field is nullable.

    Raises
    ------
    PySparkValueError
        If ``value`` is ``None`` while the field is not nullable, or if
        ``value`` is not a ``Geography`` instance (``MALFORMED_GEOGRAPHY``).
    """
    if value is None:
        if not nullable:
            raise PySparkValueError(f"input for {dataType} must not be None")
        return None
    if isinstance(value, Geography):
        return dataType.toInternal(value)
    # The MALFORMED_GEOGRAPHY template has no placeholders, so hand the error
    # an explicit empty parameter mapping instead of leaving it as None.
    raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY", messageParameters={})

return convert_geography

elif isinstance(dataType, GeometryType):

def convert_geometry(value: Any) -> Any:
    """Convert a ``Geometry`` value via ``dataType.toInternal``.

    ``None`` is passed through when the enclosing field is nullable.

    Raises
    ------
    PySparkValueError
        If ``value`` is ``None`` while the field is not nullable, or if
        ``value`` is not a ``Geometry`` instance (``MALFORMED_GEOMETRY``).
    """
    if value is None:
        if not nullable:
            raise PySparkValueError(f"input for {dataType} must not be None")
        return None
    if isinstance(value, Geometry):
        return dataType.toInternal(value)
    # The MALFORMED_GEOMETRY template has no placeholders, so hand the error
    # an explicit empty parameter mapping instead of leaving it as None.
    raise PySparkValueError(errorClass="MALFORMED_GEOMETRY", messageParameters={})

return convert_geometry

elif not nullable:

def convert_other(value: Any) -> Any:
Expand Down Expand Up @@ -511,6 +547,10 @@ def _need_converter(dataType: DataType) -> bool:
return True
elif isinstance(dataType, VariantType):
return True
elif isinstance(dataType, GeographyType):
return True
elif isinstance(dataType, GeometryType):
return True
else:
return False

Expand Down Expand Up @@ -719,6 +759,40 @@ def convert_variant(value: Any) -> Any:

return convert_variant

elif isinstance(dataType, GeographyType):

def convert_geography(value: Any) -> Any:
    """Build a ``Geography`` from a dict carrying ``"wkb"`` and ``"srid"``.

    ``None`` is returned unchanged. Any other input must be a dict that
    has both keys, with ``bytes`` under ``"wkb"`` and an ``int`` under
    ``"srid"``; extra keys are ignored.

    Raises
    ------
    PySparkValueError
        With error class ``MALFORMED_GEOGRAPHY`` when the input does not
        have the expected shape.
    """
    if value is None:
        return None
    if (
        isinstance(value, dict)
        and all(key in value for key in ["wkb", "srid"])
        and isinstance(value["wkb"], bytes)
        and isinstance(value["srid"], int)
    ):
        return Geography.fromWKB(value["wkb"], value["srid"])
    # The MALFORMED_GEOGRAPHY template has no placeholders, so hand the error
    # an explicit empty parameter mapping instead of leaving it as None.
    raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY", messageParameters={})

return convert_geography

elif isinstance(dataType, GeometryType):

def convert_geometry(value: Any) -> Any:
    """Build a ``Geometry`` from a dict carrying ``"wkb"`` and ``"srid"``.

    ``None`` is returned unchanged. Any other input must be a dict that
    has both keys, with ``bytes`` under ``"wkb"`` and an ``int`` under
    ``"srid"``; extra keys are ignored.

    Raises
    ------
    PySparkValueError
        With error class ``MALFORMED_GEOMETRY`` when the input does not
        have the expected shape.
    """
    if value is None:
        return None
    if (
        isinstance(value, dict)
        and all(key in value for key in ["wkb", "srid"])
        and isinstance(value["wkb"], bytes)
        and isinstance(value["srid"], int)
    ):
        return Geometry.fromWKB(value["wkb"], value["srid"])
    # The MALFORMED_GEOMETRY template has no placeholders, so hand the error
    # an explicit empty parameter mapping instead of leaving it as None.
    raise PySparkValueError(errorClass="MALFORMED_GEOMETRY", messageParameters={})

return convert_geometry

else:
if none_on_identity:
return None
Expand Down
38 changes: 38 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_geographytype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase):
    """Parity suite combining ``GeographyTypeTestMixin`` with ``ReusedConnectTestCase``.

    The class adds no tests of its own; everything it runs is inherited
    from its two bases.
    """


if __name__ == "__main__":
    # ``unittest`` is already imported at module level, so it is not imported again here.
    from pyspark.sql.tests.connect.test_parity_geographytype import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; with None, unittest.main uses its default runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
38 changes: 38 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_geometrytype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase):
    """Parity suite combining ``GeometryTypeTestMixin`` with ``ReusedConnectTestCase``.

    The class adds no tests of its own; everything it runs is inherited
    from its two bases.
    """


if __name__ == "__main__":
    # ``unittest`` is already imported at module level, so it is not imported again here.
    from pyspark.sql.tests.connect.test_parity_geometrytype import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; with None, unittest.main uses its default runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
Loading