Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,8 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_retry",
"pyspark.sql.tests.connect.test_connect_session",
"pyspark.sql.tests.connect.test_connect_stat",
"pyspark.sql.tests.connect.test_parity_geographytype",
"pyspark.sql.tests.connect.test_parity_geometrytype",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,16 @@
"<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
]
},
"MALFORMED_GEOGRAPHY": {
"message": [
"Geography binary is malformed. Please check the data source is valid."
]
},
"MALFORMED_GEOMETRY": {
"message": [
"Geometry binary is malformed. Please check the data source is valid."
]
},
"MALFORMED_VARIANT": {
"message": [
"Variant binary is malformed. Please check the data source is valid."
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/sql/connect/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
NullType,
NumericType,
VariantType,
GeographyType,
GeometryType,
UserDefinedType,
)
from pyspark.errors import PySparkAssertionError, PySparkValueError
Expand Down Expand Up @@ -191,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret.array.contains_null = data_type.containsNull
elif isinstance(data_type, VariantType):
ret.variant.CopyFrom(pb2.DataType.Variant())
elif isinstance(data_type, GeometryType):
ret.geometry.srid = data_type.srid
elif isinstance(data_type, GeographyType):
ret.geography.srid = data_type.srid
elif isinstance(data_type, UserDefinedType):
json_value = data_type.jsonValue()
ret.udt.type = "udt"
Expand Down Expand Up @@ -303,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
)
elif schema.HasField("variant"):
return VariantType()
elif schema.HasField("geometry"):
srid = schema.geometry.srid
if srid == GeometryType.MIXED_SRID:
return GeometryType("ANY")
else:
return GeometryType(srid)
elif schema.HasField("geography"):
srid = schema.geography.srid
if srid == GeographyType.MIXED_SRID:
return GeographyType("ANY")
else:
return GeographyType(srid)
elif schema.HasField("udt"):
assert schema.udt.type == "udt"
json_value = {}
Expand Down
74 changes: 74 additions & 0 deletions python/pyspark/sql/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
BinaryType,
DataType,
DecimalType,
GeographyType,
Geography,
GeometryType,
Geometry,
MapType,
NullType,
Row,
Expand Down Expand Up @@ -89,6 +93,10 @@ def _need_converter(
return True
elif isinstance(dataType, VariantType):
return True
elif isinstance(dataType, GeometryType):
return True
elif isinstance(dataType, GeographyType):
return True
else:
return False

Expand Down Expand Up @@ -392,6 +400,34 @@ def convert_variant(value: Any) -> Any:

return convert_variant

elif isinstance(dataType, GeographyType):

def convert_geography(value: Any) -> Any:
    """Convert a ``Geography`` value with ``dataType.toInternal``.

    ``None`` is returned unchanged when the enclosing ``nullable`` flag is set.

    Raises
    ------
    PySparkValueError
        If ``value`` is ``None`` while ``nullable`` is false, or if ``value`` is
        neither ``None`` nor a ``Geography`` instance (``MALFORMED_GEOGRAPHY``).
    """
    if value is not None:
        # Anything that is not a Geography instance is rejected outright.
        if not isinstance(value, Geography):
            raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
        return dataType.toInternal(value)

    if nullable:
        return None
    raise PySparkValueError(f"input for {dataType} must not be None")

return convert_geography

elif isinstance(dataType, GeometryType):

def convert_geometry(value: Any) -> Any:
    """Convert a ``Geometry`` value with ``dataType.toInternal``.

    ``None`` is returned unchanged when the enclosing ``nullable`` flag is set.

    Raises
    ------
    PySparkValueError
        If ``value`` is ``None`` while ``nullable`` is false, or if ``value`` is
        neither ``None`` nor a ``Geometry`` instance (``MALFORMED_GEOMETRY``).
    """
    if value is not None:
        # Anything that is not a Geometry instance is rejected outright.
        if not isinstance(value, Geometry):
            raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
        return dataType.toInternal(value)

    if nullable:
        return None
    raise PySparkValueError(f"input for {dataType} must not be None")

return convert_geometry

elif not nullable:

def convert_other(value: Any) -> Any:
Expand Down Expand Up @@ -511,6 +547,10 @@ def _need_converter(dataType: DataType) -> bool:
return True
elif isinstance(dataType, VariantType):
return True
elif isinstance(dataType, GeographyType):
return True
elif isinstance(dataType, GeometryType):
return True
else:
return False

Expand Down Expand Up @@ -719,6 +759,40 @@ def convert_variant(value: Any) -> Any:

return convert_variant

elif isinstance(dataType, GeographyType):

def convert_geography(value: Any) -> Any:
    """Build a ``Geography`` from a dict holding ``"wkb"`` and ``"srid"``.

    ``None`` is passed through. Any other input must be a dict containing
    both keys, with ``bytes`` under ``"wkb"`` and an ``int`` under ``"srid"``.

    Raises
    ------
    PySparkValueError
        With error class ``MALFORMED_GEOGRAPHY`` when the input does not have
        the expected shape.
    """
    if value is None:
        return None

    # Key presence is checked before the values are read, so a missing key
    # leads to the malformed error below rather than a KeyError.
    well_formed = (
        isinstance(value, dict)
        and "wkb" in value
        and "srid" in value
        and isinstance(value["wkb"], bytes)
        and isinstance(value["srid"], int)
    )
    if not well_formed:
        raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
    return Geography.fromWKB(value["wkb"], value["srid"])

return convert_geography

elif isinstance(dataType, GeometryType):

def convert_geometry(value: Any) -> Any:
    """Build a ``Geometry`` from a dict holding ``"wkb"`` and ``"srid"``.

    ``None`` is passed through. Any other input must be a dict containing
    both keys, with ``bytes`` under ``"wkb"`` and an ``int`` under ``"srid"``.

    Raises
    ------
    PySparkValueError
        With error class ``MALFORMED_GEOMETRY`` when the input does not have
        the expected shape.
    """
    if value is None:
        return None

    # Key presence is checked before the values are read, so a missing key
    # leads to the malformed error below rather than a KeyError.
    well_formed = (
        isinstance(value, dict)
        and "wkb" in value
        and "srid" in value
        and isinstance(value["wkb"], bytes)
        and isinstance(value["srid"], int)
    )
    if not well_formed:
        raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
    return Geometry.fromWKB(value["wkb"], value["srid"])

return convert_geometry

else:
if none_on_identity:
return None
Expand Down
38 changes: 38 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_geographytype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase):
    """Combines ``GeographyTypeTestMixin`` with ``ReusedConnectTestCase``; adds no tests."""


if __name__ == "__main__":
    # ``unittest`` is already imported at module level, so it is not imported again here.
    from pyspark.sql.tests.connect.test_parity_geographytype import *  # noqa: F401

    try:
        import xmlrunner

        # When ``xmlrunner`` is installed, use its XML runner with output directed
        # to ``target/test-reports``.
        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # ``xmlrunner`` is not available; ``None`` is handed to ``unittest.main``.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
38 changes: 38 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_geometrytype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase):
    """Combines ``GeometryTypeTestMixin`` with ``ReusedConnectTestCase``; adds no tests."""


if __name__ == "__main__":
    # ``unittest`` is already imported at module level, so it is not imported again here.
    from pyspark.sql.tests.connect.test_parity_geometrytype import *  # noqa: F401

    try:
        import xmlrunner

        # When ``xmlrunner`` is installed, use its XML runner with output directed
        # to ``target/test-reports``.
        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # ``xmlrunner`` is not available; ``None`` is handed to ``unittest.main``.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
Loading