Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions python/pyspark/sql/geo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional


# Immutable record describing a single spatial reference system (SRS).
@dataclass(frozen=True)
class SpatialReferenceSystemInformation:
    # Integer spatial reference identifier (SRID) of this SRS.
    srid: int
    # String ID of the coordinate reference system (CRS) paired with this SRID.
    string_id: str
    # True when this SRS is geographic, False otherwise.
    is_geographic: bool


# Registry of the supported SRID / CRS values, each mapped to its SRS information record.
class SpatialReferenceSystemCache:
    # Shared instance, created on first request. The singleton pattern is aligned with the
    # JVM side / Scala implementation.
    _instance: Optional["SpatialReferenceSystemCache"] = None

    # Returns the shared cache, constructing it the first time it is requested.
    @classmethod
    def instance(cls) -> "SpatialReferenceSystemCache":
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = SpatialReferenceSystemCache()
        return cls._instance

    def __init__(self) -> None:
        # Lookup table from the integer SRID to the full SRS information.
        self._srid_to_srs: Dict[int, SpatialReferenceSystemInformation] = {}
        # Lookup table from the string ID to the full SRS information.
        self._string_id_to_srs: Dict[str, SpatialReferenceSystemInformation] = {}
        self._populate_srs_information_mapping()

    # Fills both lookup tables (by SRID and by string ID) with the supported SRS entries.
    def _populate_srs_information_mapping(self) -> None:
        # Only a limited set of SRID / CRS values is supported for now, even on the Scala
        # side. Each row below is (srid, string_id, is_geographic).
        supported = (
            (0, "SRID:0", False),
            (3857, "EPSG:3857", False),
            (4326, "OGC:CRS84", True),
        )
        for srid, string_id, is_geographic in supported:
            # A single record object is shared by both tables, so nothing is duplicated.
            entry = SpatialReferenceSystemInformation(srid, string_id, is_geographic)
            self._srid_to_srs[srid] = entry
            self._string_id_to_srs[string_id] = entry

    # Looks up the SRS registered for the given SRID; yields `None` when it is not supported.
    def get_srs_by_srid(self, srid: int) -> Optional[SpatialReferenceSystemInformation]:
        match = self._srid_to_srs.get(srid)
        return match

    # Looks up the SRS registered for the given string ID; yields `None` when not supported.
    def get_srs_by_string_id(self, string_id: str) -> Optional[SpatialReferenceSystemInformation]:
        match = self._string_id_to_srs.get(string_id)
        return match


# Translates between SRID and string ID values, restricted to geographic SRS entries.
class GeographicSpatialReferenceSystemMapper:
    # Maps an SRID to its string ID. Yields `None` when the SRID is not in the cache or
    # when its SRS is not geographic.
    @staticmethod
    def get_string_id(srid: int) -> Optional[str]:
        info = SpatialReferenceSystemCache.instance().get_srs_by_srid(srid)
        if info is None or not info.is_geographic:
            return None
        return info.string_id

    # Maps a string ID to its SRID. Yields `None` when the string ID is not in the cache or
    # when its SRS is not geographic.
    @staticmethod
    def get_srid(string_id: str) -> Optional[int]:
        info = SpatialReferenceSystemCache.instance().get_srs_by_string_id(string_id)
        if info is None or not info.is_geographic:
            return None
        return info.srid


# Translates between SRID and string ID values for cartesian use. Every SRS held by the
# cache is accepted here; the `is_geographic` flag is not consulted.
class CartesianSpatialReferenceSystemMapper:
    # Maps an SRID to its string ID. Yields `None` when the SRID is not in the cache.
    @staticmethod
    def get_string_id(srid: int) -> Optional[str]:
        info = SpatialReferenceSystemCache.instance().get_srs_by_srid(srid)
        if info is None:
            return None
        return info.string_id

    # Maps a string ID to its SRID. Yields `None` when the string ID is not in the cache.
    @staticmethod
    def get_srid(string_id: str) -> Optional[int]:
        info = SpatialReferenceSystemCache.instance().get_srs_by_string_id(string_id)
        if info is None:
            return None
        return info.srid
10 changes: 7 additions & 3 deletions python/pyspark/sql/tests/test_geometrytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GeometryTypeTestMixin:
def test_geometrytype_specified_valid_srid(self):
"""Test that GeometryType is constructed correctly when a valid SRID is specified."""

supported_srid = {4326: "OGC:CRS84"}
supported_srid = {0: "SRID:0", 3857: "EPSG:3857", 4326: "OGC:CRS84"}

for srid, crs in supported_srid.items():
geometry_type = GeometryType(srid)
Expand Down Expand Up @@ -65,18 +65,20 @@ def test_geometrytype_any_specifier(self):
def test_geometrytype_same_srid_values(self):
"""Test that two GeometryTypes with specified SRIDs have the same SRID values."""

for srid in [4326]:
for srid in [0, 3857, 4326]:
geometry_type_1 = GeometryType(srid)
geometry_type_2 = GeometryType(srid)
self.assertEqual(geometry_type_1.srid, geometry_type_2.srid)

def test_geometrytype_different_srid_values(self):
"""Test that two GeometryTypes with specified SRIDs have different SRID values."""

for srid in [4326]:
for srid in [0, 4326]:
geometry_type_1 = GeometryType(srid)
geometry_type_2 = GeometryType("ANY")
self.assertNotEqual(geometry_type_1.srid, geometry_type_2.srid)
geometry_type_3 = GeometryType(3857)
self.assertNotEqual(geometry_type_1.srid, geometry_type_3.srid)

# The tests below verify GEOMETRY type JSON parsing based on the CRS value.

Expand All @@ -96,6 +98,8 @@ def test_geometrytype_from_valid_crs(self):
"""Test that GeometryType construction passes when a valid CRS is specified."""

supported_crs = {
"SRID:0": 0,
"EPSG:3857": 3857,
"OGC:CRS84": 4326,
}
for valid_crs, srid in supported_crs.items():
Expand Down
20 changes: 12 additions & 8 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@
PySparkAttributeError,
PySparkKeyError,
)
from pyspark.sql.geo_utils import (
GeographicSpatialReferenceSystemMapper as _GeographicSRSMapper,
CartesianSpatialReferenceSystemMapper as _CartesianSRSMapper,
)

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -549,8 +553,8 @@ def __init__(self, srid: int | str):
if srid == "ANY":
self.srid = GeographyType.MIXED_SRID
self._crs = GeographyType.MIXED_CRS
# Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value.
elif not isinstance(srid, int) or srid != GeographyType.DEFAULT_SRID:
# Otherwise, the parameterized GEOGRAPHY type requires a valid integer SRID value.
elif not isinstance(srid, int) or (crs := _GeographicSRSMapper.get_string_id(srid)) is None:
raise IllegalArgumentException(
errorClass="ST_INVALID_SRID_VALUE",
messageParameters={
Expand All @@ -559,7 +563,7 @@ def __init__(self, srid: int | str):
)
else:
self.srid = srid
self._crs = GeographyType.DEFAULT_CRS
self._crs = crs
self._alg = GeographyType.DEFAULT_ALG

@classmethod
Expand All @@ -577,7 +581,7 @@ def _from_crs(cls, crs: str, alg: str) -> "GeographyType":
if crs.lower() == cls.MIXED_CRS.lower():
return GeographyType("ANY")
# Otherwise, JSON parsing for the GEOGRAPHY type requires a valid CRS value.
srid = GeographyType.DEFAULT_SRID if crs == "OGC:CRS84" else None
srid = _GeographicSRSMapper.get_srid(crs)
if srid is None:
raise IllegalArgumentException(
errorClass="ST_INVALID_CRS_VALUE",
Expand Down Expand Up @@ -638,8 +642,8 @@ def __init__(self, srid: int | str):
if srid == "ANY":
self.srid = GeometryType.MIXED_SRID
self._crs = GeometryType.MIXED_CRS
# Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value.
elif not isinstance(srid, int) or srid != GeometryType.DEFAULT_SRID:
# Otherwise, the parameterized GEOMETRY type requires a valid integer SRID value.
elif not isinstance(srid, int) or (crs := _CartesianSRSMapper.get_string_id(srid)) is None:
raise IllegalArgumentException(
errorClass="ST_INVALID_SRID_VALUE",
messageParameters={
Expand All @@ -649,7 +653,7 @@ def __init__(self, srid: int | str):
# If the SRID is valid, initialize the GEOMETRY type with the corresponding CRS value.
else:
self.srid = srid
self._crs = GeometryType.DEFAULT_CRS
self._crs = crs

""" JSON parsing logic for the GEOMETRY type relies on the CRS value, instead of the SRID.
The method can accept either a single valid geometric string CRS value, or a special case
Expand All @@ -662,7 +666,7 @@ def _from_crs(cls, crs: str) -> "GeometryType":
if crs.lower() == cls.MIXED_CRS.lower():
return GeometryType("ANY")
# Otherwise, JSON parsing for the GEOMETRY type requires a valid CRS value.
srid = GeometryType.DEFAULT_SRID if crs == "OGC:CRS84" else None
srid = _CartesianSRSMapper.get_srid(crs)
if srid is None:
raise IllegalArgumentException(
errorClass="ST_INVALID_CRS_VALUE",
Expand Down