Skip to content
Closed
4 changes: 3 additions & 1 deletion python/pyspark/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
- :class:`pyspark.sql.Window`
For working with window functions.
"""
from pyspark.sql.types import Row, VariantVal
from pyspark.sql.types import Geography, Geometry, Row, VariantVal
from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
Expand Down Expand Up @@ -69,6 +69,8 @@
"DataFrameNaFunctions",
"DataFrameStatFunctions",
"VariantVal",
"Geography",
"Geometry",
"Window",
"WindowSpec",
"DataFrameReader",
Expand Down
120 changes: 120 additions & 0 deletions python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
UserDefinedType,
VariantType,
VariantVal,
GeometryType,
Geometry,
GeographyType,
Geography,
_create_row,
)
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
Expand Down Expand Up @@ -202,6 +206,28 @@ def to_arrow_type(
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
]
arrow_type = pa.struct(fields)
elif type(dt) == GeometryType:
fields = [
pa.field("srid", pa.int32(), nullable=False),
pa.field(
"wkb",
pa.binary(),
nullable=False,
metadata={b"geometry": b"true", b"srid": str(dt.srid)},
),
]
arrow_type = pa.struct(fields)
elif type(dt) == GeographyType:
fields = [
pa.field("srid", pa.int32(), nullable=False),
pa.field(
"wkb",
pa.binary(),
nullable=False,
metadata={b"geography": b"true", b"srid": str(dt.srid)},
),
]
arrow_type = pa.struct(fields)
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
Expand Down Expand Up @@ -272,6 +298,38 @@ def is_variant(at: "pa.DataType") -> bool:
) and any(field.name == "value" for field in at)


def is_geometry(at: "pa.DataType") -> bool:
    """Check if a PyArrow struct data type represents a geometry.

    A struct is treated as a geometry when it has a field named ``srid`` and a
    field named ``wkb`` whose field-level metadata maps ``b"geometry"`` to
    ``b"true"``.

    Parameters
    ----------
    at : pyarrow.DataType
        The Arrow type to inspect. It must be a struct type.

    Returns
    -------
    bool
        True if the struct carries the geometry marker, False otherwise.
    """
    import pyarrow.types as types

    assert types.is_struct(at)

    def _is_tagged_wkb(field: "pa.Field") -> bool:
        # Field.metadata is None (not an empty dict) when the field carries no
        # metadata, so a plain struct that merely has a "wkb" field must not
        # blow up on the membership test.
        metadata = field.metadata
        return (
            field.name == "wkb"
            and metadata is not None
            and metadata.get(b"geometry") == b"true"
        )

    return any(_is_tagged_wkb(field) for field in at) and any(
        field.name == "srid" for field in at
    )


def is_geography(at: "pa.DataType") -> bool:
    """Check if a PyArrow struct data type represents a geography.

    A struct is treated as a geography when it has a field named ``srid`` and a
    field named ``wkb`` whose field-level metadata maps ``b"geography"`` to
    ``b"true"``.

    Parameters
    ----------
    at : pyarrow.DataType
        The Arrow type to inspect. It must be a struct type.

    Returns
    -------
    bool
        True if the struct carries the geography marker, False otherwise.
    """
    import pyarrow.types as types

    assert types.is_struct(at)

    def _is_tagged_wkb(field: "pa.Field") -> bool:
        # Field.metadata is None (not an empty dict) when the field carries no
        # metadata, so a plain struct that merely has a "wkb" field must not
        # blow up on the membership test.
        metadata = field.metadata
        return (
            field.name == "wkb"
            and metadata is not None
            and metadata.get(b"geography") == b"true"
        )

    return any(_is_tagged_wkb(field) for field in at) and any(
        field.name == "srid" for field in at
    )


def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
Expand Down Expand Up @@ -337,6 +395,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
elif types.is_struct(at):
if is_variant(at):
return VariantType()
elif is_geometry(at):
srid = int(at.field("wkb").metadata.get(b"srid"))
if srid == GeometryType.MIXED_SRID:
return GeometryType("ANY")
else:
return GeometryType(srid)
elif is_geography(at):
srid = int(at.field("wkb").metadata.get(b"srid"))
if srid == GeographyType.MIXED_SRID:
return GeographyType("ANY")
else:
return GeographyType(srid)
return StructType(
[
StructField(
Expand Down Expand Up @@ -1098,6 +1168,40 @@ def convert_variant(value: Any) -> Any:

return convert_variant

elif isinstance(dt, GeographyType):

def convert_geography(value: Any) -> Any:
    """Turn an Arrow-side ``{"wkb": bytes, "srid": int}`` dict into a Geography.

    ``None`` is passed through unchanged. Any other input that is not a dict
    holding a bytes ``wkb`` and an int ``srid`` raises ``PySparkValueError``
    with error class ``MALFORMED_GEOGRAPHY``.
    """
    if value is None:
        return None

    well_formed = (
        isinstance(value, dict)
        and "wkb" in value
        and "srid" in value
        and isinstance(value["wkb"], bytes)
        and isinstance(value["srid"], int)
    )
    if not well_formed:
        raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")

    return Geography.fromWKB(value["wkb"], value["srid"])

return convert_geography

elif isinstance(dt, GeometryType):

def convert_geometry(value: Any) -> Any:
    """Turn an Arrow-side ``{"wkb": bytes, "srid": int}`` dict into a Geometry.

    ``None`` is passed through unchanged. Any other input that is not a dict
    holding a bytes ``wkb`` and an int ``srid`` raises ``PySparkValueError``
    with error class ``MALFORMED_GEOMETRY``.
    """
    if value is None:
        return None

    well_formed = (
        isinstance(value, dict)
        and "wkb" in value
        and "srid" in value
        and isinstance(value["wkb"], bytes)
        and isinstance(value["srid"], int)
    )
    if not well_formed:
        raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")

    return Geometry.fromWKB(value["wkb"], value["srid"])

return convert_geometry

else:
return None

Expand Down Expand Up @@ -1360,6 +1464,22 @@ def convert_variant(variant: Any) -> Any:

return convert_variant

elif isinstance(dt, GeographyType):

def convert_geography(value: Any) -> Any:
    """Flatten a Geography into the ``srid``/``wkb`` dict used on the Arrow side."""
    assert isinstance(value, Geography)
    srid, wkb = value.srid, value.wkb
    return dict(srid=srid, wkb=wkb)

return convert_geography

elif isinstance(dt, GeometryType):

def convert_geometry(value: Any) -> Any:
    """Flatten a Geometry into the ``srid``/``wkb`` dict used on the Arrow side."""
    assert isinstance(value, Geometry)
    srid, wkb = value.srid, value.wkb
    return dict(srid=srid, wkb=wkb)

return convert_geometry

return None

conv = _converter(data_type)
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_apply_schema_to_dict_and_rows(self):
def test_apply_schema_to_row(self):
super().test_apply_schema_to_row()

# Parity override: the body only delegates to the inherited test, so removing
# the skip decorator is all that is needed to run it under Spark Connect.
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_geospatial_create_dataframe_rdd(self):
super().test_geospatial_create_dataframe_rdd()

# Parity override: the body only delegates to the inherited test, so removing
# the skip decorator is all that is needed to run it under Spark Connect.
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_create_dataframe_schema_mismatch(self):
super().test_create_dataframe_schema_mismatch()
Expand Down
Loading