Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions python/pyspark/sql/tests/test_geographytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,47 @@ def test_geographytype_different_srid_values(self):
geography_type_2 = GeographyType("ANY")
self.assertNotEqual(geography_type_1.srid, geography_type_2.srid)

def test_geographytype_from_invalid_crs(self):
"""Test that GeographyType construction fails when an invalid CRS is specified."""

for invalid_crs in ["srid", "any", "ogccrs84", "ogc:crs84", "ogc:CRS84", "asdf", ""]:
with self.assertRaises(IllegalArgumentException) as error_context:
GeographyType._from_crs(invalid_crs, "SPHERICAL")
crs_header = "[ST_INVALID_CRS_VALUE] Invalid or unsupported CRS"
self.assertEqual(
str(error_context.exception),
f"{crs_header} (coordinate reference system) value: '{invalid_crs}'.",
)

# The tests below verify GEOGRAPHY type JSON parsing based on CRS and algorithm.

def test_geographytype_from_invalid_algorithm(self):
"""Test that GeographyType construction fails when an invalid CRS is specified."""

for invalid_alg in ["alg", "algorithm", "KARNEY", "spherical", "SPHEROID", "asdf", ""]:
with self.assertRaises(IllegalArgumentException) as error_context:
GeographyType._from_crs("OGC:CRS84", invalid_alg)
alg_header = "[ST_INVALID_ALGORITHM_VALUE] Invalid or unsupported"
self.assertEqual(
str(error_context.exception),
f"{alg_header} edge interpolation algorithm value: '{invalid_alg}'.",
)

def test_geographytype_from_valid_crs_and_algorithm(self):
"""Test that GeographyType construction passes when valid CRS & ALG are specified."""

supported_crs = {
"OGC:CRS84": 4326,
}
for valid_crs, srid in supported_crs.items():
for valid_alg in ["SPHERICAL"]:
geography_type = GeographyType._from_crs(valid_crs, valid_alg)
self.assertEqual(geography_type.srid, srid)
self.assertEqual(geography_type.typeName(), "geography")
self.assertEqual(geography_type.simpleString(), f"geography({srid})")
self.assertEqual(geography_type.jsonValue(), f"geography({valid_crs}, {valid_alg})")
self.assertEqual(repr(geography_type), f"GeographyType({srid})")


class GeographyTypeTest(GeographyTypeTestMixin, ReusedSQLTestCase):
pass
Expand Down
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/test_geometrytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,34 @@ def test_geometrytype_different_srid_values(self):
geometry_type_2 = GeometryType("ANY")
self.assertNotEqual(geometry_type_1.srid, geometry_type_2.srid)

# The tests below verify GEOMETRY type JSON parsing based on the CRS value.

def test_geometrytype_from_invalid_crs(self):
"""Test that GeometryType construction fails when an invalid CRS is specified."""

for invalid_crs in ["srid", "any", "ogccrs84", "ogc:crs84", "ogc:CRS84", "asdf", ""]:
with self.assertRaises(IllegalArgumentException) as error_context:
GeometryType._from_crs(invalid_crs)
crs_header = "[ST_INVALID_CRS_VALUE] Invalid or unsupported CRS"
self.assertEqual(
str(error_context.exception),
f"{crs_header} (coordinate reference system) value: '{invalid_crs}'.",
)

def test_geometrytype_from_valid_crs(self):
"""Test that GeometryType construction passes when a valid CRS is specified."""

supported_crs = {
"OGC:CRS84": 4326,
}
for valid_crs, srid in supported_crs.items():
geometry_type = GeometryType._from_crs(valid_crs)
self.assertEqual(geometry_type.srid, srid)
self.assertEqual(geometry_type.typeName(), "geometry")
self.assertEqual(geometry_type.simpleString(), f"geometry({srid})")
self.assertEqual(geometry_type.jsonValue(), f"geometry({valid_crs})")
self.assertEqual(repr(geometry_type), f"GeometryType({srid})")


class GeometryTypeTest(GeometryTypeTestMixin, ReusedSQLTestCase):
pass
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def _from_crs(cls, crs: str, alg: str) -> "GeographyType":
# Algorithm value must be validated, although only SPHERICAL is supported currently.
if alg != cls.DEFAULT_ALG:
raise IllegalArgumentException(
errorClass="INVALID_ALGORITHM_VALUE",
errorClass="ST_INVALID_ALGORITHM_VALUE",
messageParameters={
"alg": str(alg),
},
Expand Down