diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 07ac4c76b91a..aa8ca58a5a75 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1114,6 +1114,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_connect_retry", "pyspark.sql.tests.connect.test_connect_session", "pyspark.sql.tests.connect.test_connect_stat", + "pyspark.sql.tests.connect.test_parity_geographytype", + "pyspark.sql.tests.connect.test_parity_geometrytype", "pyspark.sql.tests.connect.test_parity_datasources", "pyspark.sql.tests.connect.test_parity_errors", "pyspark.sql.tests.connect.test_parity_catalog", diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index d169e6293a1b..51bbdd862516 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -549,6 +549,16 @@ " and should be of the same length, got and ." ] }, + "MALFORMED_GEOGRAPHY": { + "message": [ + "Geography binary is malformed. Please check the data source is valid." + ] + }, + "MALFORMED_GEOMETRY": { + "message": [ + "Geometry binary is malformed. Please check the data source is valid." + ] + }, "MALFORMED_VARIANT": { "message": [ "Variant binary is malformed. Please check the data source is valid." diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 7e8f76861079..d3352b618d7c 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -50,6 +50,8 @@ NullType, NumericType, VariantType, + GeographyType, + GeometryType, UserDefinedType, ) from pyspark.errors import PySparkAssertionError, PySparkValueError @@ -191,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: ret.array.contains_null = data_type.containsNull elif isinstance(data_type, VariantType): ret.variant.CopyFrom(pb2.DataType.Variant()) + elif isinstance(data_type, GeometryType): + ret.geometry.srid = data_type.srid + elif isinstance(data_type, GeographyType): + ret.geography.srid = data_type.srid elif isinstance(data_type, UserDefinedType): json_value = data_type.jsonValue() ret.udt.type = "udt" @@ -303,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: ) elif schema.HasField("variant"): return VariantType() + elif schema.HasField("geometry"): + srid = schema.geometry.srid + if srid == GeometryType.MIXED_SRID: + return GeometryType("ANY") + else: + return GeometryType(srid) + elif schema.HasField("geography"): + srid = schema.geography.srid + if srid == GeographyType.MIXED_SRID: + return GeographyType("ANY") + else: + return GeographyType(srid) elif schema.HasField("udt"): assert schema.udt.type == "udt" json_value = {} diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index a8f621277a0a..f73727d1d534 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -28,6 +28,10 @@ BinaryType, DataType, DecimalType, + GeographyType, + Geography, + GeometryType, + Geometry, MapType, NullType, Row, @@ -89,6 +93,10 @@ def _need_converter( return True elif isinstance(dataType, VariantType): return True + elif isinstance(dataType, GeometryType): + return True + elif isinstance(dataType, GeographyType): + return True else: return False @@ -392,6 +400,34 @@ def convert_variant(value: Any) -> Any: return convert_variant + elif isinstance(dataType, GeographyType): + + def convert_geography(value: Any) -> Any: + if value is None: + if not nullable: + raise PySparkValueError(f"input for {dataType} must not be None") + return None + elif isinstance(value, Geography): + return dataType.toInternal(value) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY") + + return convert_geography + + elif isinstance(dataType, GeometryType): + + def convert_geometry(value: Any) -> Any: + if value is None: + if not nullable: + raise PySparkValueError(f"input for {dataType} must not be None") + return None + elif isinstance(value, Geometry): + return dataType.toInternal(value) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOMETRY") + + return convert_geometry + elif not nullable: def convert_other(value: Any) -> Any: @@ -511,6 +547,10 @@ def _need_converter(dataType: DataType) -> bool: return True elif isinstance(dataType, VariantType): return True + elif isinstance(dataType, GeographyType): + return True + elif isinstance(dataType, GeometryType): + return True else: return False @@ -719,6 +759,40 @@ def convert_variant(value: Any) -> Any: return convert_variant + elif isinstance(dataType, GeographyType): + + def convert_geography(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geography.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY") + + return convert_geography + + elif isinstance(dataType, GeometryType): + + def convert_geometry(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geometry.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOMETRY") + + return convert_geometry + else: if none_on_identity: return None diff --git a/python/pyspark/sql/tests/connect/test_parity_geographytype.py b/python/pyspark/sql/tests/connect/test_parity_geographytype.py new file mode 100644 index 000000000000..501bbed20ff1 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_geographytype.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_geographytype import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_geometrytype.py b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py new file mode 100644 index 000000000000..b95321b3c61b --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_geometrytype import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 440100dba931..8aae39880072 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -90,6 +90,8 @@ "TimestampNTZType", "DecimalType", "DoubleType", + "Geography", + "Geometry", "FloatType", "ByteType", "IntegerType", @@ -616,6 +618,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]: # The JSON representation always uses the CRS and algorithm value. return f"geography({self._crs}, {self._alg})" + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Dict) -> Optional["Geography"]: + if obj is None or not all(key in obj for key in ["srid", "bytes"]): + return None + return Geography(obj["bytes"], obj["srid"]) + + def toInternal(self, geography: Any) -> Any: + if geography is None: + return None + assert isinstance(geography, Geography) + return {"srid": geography.srid, "wkb": geography.wkb} + class GeometryType(SpatialType): """ @@ -700,6 +716,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]: # The JSON representation always uses the CRS value. return f"geometry({self._crs})" + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Dict) -> Optional["Geometry"]: + if obj is None or not all(key in obj for key in ["srid", "bytes"]): + return None + return Geometry(obj["bytes"], obj["srid"]) + + def toInternal(self, geometry: Any) -> Any: + if geometry is None: + return None + assert isinstance(geometry, Geometry) + return {"srid": geometry.srid, "wkb": geometry.wkb} + class ByteType(IntegralType): """Byte data type, representing signed 8-bit integers.""" @@ -2039,6 +2069,144 @@ def parseJson(cls, json_str: str) -> "VariantVal": return VariantVal(value, metadata) +class Geography: + """ + A class to represent a Geography value in Python. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + wkb : bytes + The bytes representing the WKB of Geography. + + srid : integer + The integer value representing SRID of Geography. + + Methods + ------- + getBytes() + Returns the WKB of Geography. + + getSrid() + Returns the SRID of Geography. + + Examples + -------- + >>> g = Geography.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 4326) + >>> g.getBytes().hex() + '010100000000000000000031400000000000001c40' + >>> g.getSrid() + 4326 + """ + + def __init__(self, wkb: bytes, srid: int): + self.wkb = wkb + self.srid = srid + + def __str__(self) -> str: + return "Geography(%r, %d)" % (self.wkb, self.srid) + + def __repr__(self) -> str: + return "Geography(%r, %d)" % (self.wkb, self.srid) + + def getSrid(self) -> int: + """ + Returns the SRID of Geography. + """ + return self.srid + + def getBytes(self) -> bytes: + """ + Returns the WKB of Geography. + """ + return self.wkb + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Geography): + # Don't attempt to compare against unrelated types. + return NotImplemented + + return self.wkb == other.wkb and self.srid == other.srid + + @classmethod + def fromWKB(cls, wkb: bytes, srid: int) -> "Geography": + """ + Construct Python Geography object from WKB. + :return: Python representation of the Geography type value. + """ + return Geography(wkb, srid) + + +class Geometry: + """ + A class to represent a Geometry value in Python. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + wkb : bytes + The bytes representing the WKB of Geometry. + + srid : integer + The integer value representing SRID of Geometry. + + Methods + ------- + getBytes() + Returns the WKB of Geometry. + + getSrid() + Returns the SRID of Geometry. + + Examples + -------- + >>> g = Geometry.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 0) + >>> g.getBytes().hex() + '010100000000000000000031400000000000001c40' + >>> g.getSrid() + 0 + """ + + def __init__(self, wkb: bytes, srid: int): + self.wkb = wkb + self.srid = srid + + def __str__(self) -> str: + return "Geometry(%r, %d)" % (self.wkb, self.srid) + + def __repr__(self) -> str: + return "Geometry(%r, %d)" % (self.wkb, self.srid) + + def getSrid(self) -> int: + """ + Returns the SRID of Geometry. + """ + return self.srid + + def getBytes(self) -> bytes: + """ + Returns the WKB of Geometry. + """ + return self.wkb + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Geometry): + # Don't attempt to compare against unrelated types. + return NotImplemented + + return self.wkb == other.wkb and self.srid == other.srid + + @classmethod + def fromWKB(cls, wkb: bytes, srid: int) -> "Geometry": + """ + Construct Python Geometry object from WKB. + :return: Python representation of the Geometry type value. + """ + return Geometry(wkb, srid) + + _atomic_types: List[Type[DataType]] = [ StringType, CharType,