diff --git a/python/pyspark/sql/geo_utils.py b/python/pyspark/sql/geo_utils.py
new file mode 100644
index 0000000000000..904f82a9946f9
--- /dev/null
+++ b/python/pyspark/sql/geo_utils.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+
+# Class for maintaining information about a spatial reference system (SRS).
+@dataclass(frozen=True)
+class SpatialReferenceSystemInformation:
+    # Field storing the spatial reference identifier (SRID) value of this SRS.
+    srid: int
+    # Field storing the string ID of the corresponding coordinate reference system (CRS).
+    string_id: str
+    # Field indicating whether the spatial reference system (SRS) is geographic or not.
+    is_geographic: bool
+
+
+# Class for maintaining the mappings between supported SRID/CRS values and the corresponding SRS.
+class SpatialReferenceSystemCache:
+    # We use a singleton pattern, which is aligned with the JVM side / Scala implementation.
+    _instance: Optional["SpatialReferenceSystemCache"] = None
+
+    @classmethod
+    def instance(cls) -> "SpatialReferenceSystemCache":
+        # Lazily build the shared cache on first access and reuse it for all later calls.
+        if cls._instance is None:
+            cls._instance = SpatialReferenceSystemCache()
+        return cls._instance
+
+    def __init__(self) -> None:
+        # Hash map for defining the mappings from the integer SRID to the full SRS information.
+        self._srid_to_srs: Dict[int, SpatialReferenceSystemInformation] = {}
+        # Hash map for defining the mappings from the string ID to the full SRS information.
+        self._string_id_to_srs: Dict[str, SpatialReferenceSystemInformation] = {}
+        self._populate_srs_information_mapping()
+
+    # Helper method for building the SRID-to-SRS and stringID-to-SRS mappings.
+    def _populate_srs_information_mapping(self) -> None:
+        # Currently, we only support a limited set of SRID / CRS values, even on Scala side. The
+        # SRS list below will be updated soon, and the maps will be populated with more SRS data.
+        srs_list = [
+            SpatialReferenceSystemInformation(0, "SRID:0", False),
+            SpatialReferenceSystemInformation(3857, "EPSG:3857", False),
+            SpatialReferenceSystemInformation(4326, "OGC:CRS84", True),
+        ]
+        # Populate the mappings using the same SRS information objects, avoiding any duplication.
+        for srs in srs_list:
+            self._srid_to_srs[srs.srid] = srs
+            self._string_id_to_srs[srs.string_id] = srs
+
+    # Returns the SRS corresponding to the input SRID. If not supported, returns `None`.
+    def get_srs_by_srid(self, srid: int) -> Optional[SpatialReferenceSystemInformation]:
+        # `bool` is a subclass of `int` and `False == 0`, so a plain dict lookup would resolve
+        # `False` to the SRS registered for SRID 0. Booleans are never valid SRID values.
+        if isinstance(srid, bool):
+            return None
+        return self._srid_to_srs.get(srid)
+
+    # Returns the SRS corresponding to the input string ID. If not supported, returns `None`.
+    def get_srs_by_string_id(self, string_id: str) -> Optional[SpatialReferenceSystemInformation]:
+        return self._string_id_to_srs.get(string_id)
+
+
+# Class for providing SRS mappings for geographic spatial reference systems.
+# Only SRS entries flagged as geographic are resolved; any other input yields `None`.
+class GeographicSpatialReferenceSystemMapper:
+    # Returns the string ID corresponding to the input SRID. If not supported, returns `None`.
+    @staticmethod
+    def get_string_id(srid: int) -> Optional[str]:
+        srs = SpatialReferenceSystemCache.instance().get_srs_by_srid(srid)
+        return srs.string_id if srs is not None and srs.is_geographic else None
+
+    # Returns the SRID corresponding to the input string ID. If not supported, returns `None`.
+    @staticmethod
+    def get_srid(string_id: str) -> Optional[int]:
+        srs = SpatialReferenceSystemCache.instance().get_srs_by_string_id(string_id)
+        return srs.srid if srs is not None and srs.is_geographic else None
+
+
+# Class for providing SRS mappings for cartesian spatial reference systems. Unlike the geographic
+# mapper, no `is_geographic` filtering is applied, so every SRS known to the cache is resolved.
+class CartesianSpatialReferenceSystemMapper:
+    # Returns the string ID corresponding to the input SRID. If not supported, returns `None`.
+    @staticmethod
+    def get_string_id(srid: int) -> Optional[str]:
+        srs = SpatialReferenceSystemCache.instance().get_srs_by_srid(srid)
+        return srs.string_id if srs is not None else None
+
+    # Returns the SRID corresponding to the input string ID. If not supported, returns `None`.
+    @staticmethod
+    def get_srid(string_id: str) -> Optional[int]:
+        srs = SpatialReferenceSystemCache.instance().get_srs_by_string_id(string_id)
+        return srs.srid if srs is not None else None
diff --git a/python/pyspark/sql/tests/test_geometrytype.py b/python/pyspark/sql/tests/test_geometrytype.py
index 764b5d6b45de3..d4d7569229d81 100644
--- a/python/pyspark/sql/tests/test_geometrytype.py
+++ b/python/pyspark/sql/tests/test_geometrytype.py
@@ -27,7 +27,7 @@ class GeometryTypeTestMixin:
     def test_geometrytype_specified_valid_srid(self):
         """Test that GeometryType is constructed correctly when a valid SRID is specified."""
 
-        supported_srid = {4326: "OGC:CRS84"}
+        supported_srid = {0: "SRID:0", 3857: "EPSG:3857", 4326: "OGC:CRS84"}
 
         for srid, crs in supported_srid.items():
             geometry_type = GeometryType(srid)
@@ -65,7 +65,7 @@ def test_geometrytype_any_specifier(self):
     def test_geometrytype_same_srid_values(self):
         """Test that two GeometryTypes with specified SRIDs have the same SRID values."""
 
-        for srid in [4326]:
+        for srid in [0, 3857, 4326]:
             geometry_type_1 = GeometryType(srid)
             geometry_type_2 = GeometryType(srid)
             self.assertEqual(geometry_type_1.srid, geometry_type_2.srid)
@@ -73,10 +73,12 @@ def test_geometrytype_same_srid_values(self):
     def test_geometrytype_different_srid_values(self):
         """Test that two GeometryTypes with specified SRIDs have different SRID values."""
 
-        for srid in [4326]:
+        for srid in [0, 4326]:
             geometry_type_1 = GeometryType(srid)
             geometry_type_2 = GeometryType("ANY")
             self.assertNotEqual(geometry_type_1.srid, geometry_type_2.srid)
+            geometry_type_3 = GeometryType(3857)
+            self.assertNotEqual(geometry_type_1.srid, geometry_type_3.srid)
 
     # The tests below verify GEOMETRY type JSON parsing based on the CRS value.
 
@@ -96,6 +98,8 @@ def test_geometrytype_from_valid_crs(self):
         """Test that GeometryType construction passes when a valid CRS is specified."""
 
         supported_crs = {
+            "SRID:0": 0,
+            "EPSG:3857": 3857,
             "OGC:CRS84": 4326,
         }
         for valid_crs, srid in supported_crs.items():
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f45822bfa6d25..7b2bbf080dfa1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -64,6 +64,10 @@
     PySparkAttributeError,
     PySparkKeyError,
 )
+from pyspark.sql.geo_utils import (
+    GeographicSpatialReferenceSystemMapper as _GeographicSRSMapper,
+    CartesianSpatialReferenceSystemMapper as _CartesianSRSMapper,
+)
 
 if TYPE_CHECKING:
     import numpy as np
@@ -549,8 +553,8 @@ def __init__(self, srid: int | str):
         if srid == "ANY":
             self.srid = GeographyType.MIXED_SRID
             self._crs = GeographyType.MIXED_CRS
-        # Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value.
-        elif not isinstance(srid, int) or srid != GeographyType.DEFAULT_SRID:
+        # Otherwise, the parameterized GEOGRAPHY type requires a valid integer SRID value.
+        elif not isinstance(srid, int) or (crs := _GeographicSRSMapper.get_string_id(srid)) is None:
             raise IllegalArgumentException(
                 errorClass="ST_INVALID_SRID_VALUE",
                 messageParameters={
@@ -559,7 +563,7 @@ def __init__(self, srid: int | str):
             )
         else:
             self.srid = srid
-            self._crs = GeographyType.DEFAULT_CRS
+            self._crs = crs
         self._alg = GeographyType.DEFAULT_ALG
 
     @classmethod
@@ -577,7 +581,7 @@ def _from_crs(cls, crs: str, alg: str) -> "GeographyType":
         if crs.lower() == cls.MIXED_CRS.lower():
             return GeographyType("ANY")
         # Otherwise, JSON parsing for the GEOGRAPHY type requires a valid CRS value.
-        srid = GeographyType.DEFAULT_SRID if crs == "OGC:CRS84" else None
+        srid = _GeographicSRSMapper.get_srid(crs)
         if srid is None:
             raise IllegalArgumentException(
                 errorClass="ST_INVALID_CRS_VALUE",
@@ -638,8 +642,8 @@ def __init__(self, srid: int | str):
         if srid == "ANY":
             self.srid = GeometryType.MIXED_SRID
             self._crs = GeometryType.MIXED_CRS
-        # Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value.
-        elif not isinstance(srid, int) or srid != GeometryType.DEFAULT_SRID:
+        # Otherwise, the parameterized GEOMETRY type requires a valid integer SRID value.
+        elif not isinstance(srid, int) or (crs := _CartesianSRSMapper.get_string_id(srid)) is None:
             raise IllegalArgumentException(
                 errorClass="ST_INVALID_SRID_VALUE",
                 messageParameters={
@@ -649,7 +653,7 @@ def __init__(self, srid: int | str):
         # If the SRID is valid, initialize the GEOMETRY type with the corresponding CRS value.
         else:
             self.srid = srid
-            self._crs = GeometryType.DEFAULT_CRS
+            self._crs = crs
     """
     JSON parsing logic for the GEOMETRY type relies on the CRS value, instead of the SRID.
     The method can accept either a single valid geometric string CRS value, or a special case
@@ -662,7 +666,7 @@ def _from_crs(cls, crs: str) -> "GeometryType":
         if crs.lower() == cls.MIXED_CRS.lower():
             return GeometryType("ANY")
         # Otherwise, JSON parsing for the GEOMETRY type requires a valid CRS value.
-        srid = GeometryType.DEFAULT_SRID if crs == "OGC:CRS84" else None
+        srid = _CartesianSRSMapper.get_srid(crs)
         if srid is None:
             raise IllegalArgumentException(
                 errorClass="ST_INVALID_CRS_VALUE",