From 3f6ea4ffafe1249e09497cb182163e07cf2fc922 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 4 Nov 2025 08:27:33 +0100 Subject: [PATCH 1/5] Initial commit --- dev/sparktestsupport/modules.py | 2 + python/pyspark/errors/error-conditions.json | 10 ++ python/pyspark/sql/connect/proto/types_pb2.py | 116 +++++++++--------- .../pyspark/sql/connect/proto/types_pb2.pyi | 59 +++++++++ python/pyspark/sql/connect/types.py | 18 +++ python/pyspark/sql/conversion.py | 74 +++++++++++ .../connect/test_parity_geographytype.py | 38 ++++++ .../tests/connect/test_parity_geometrytype.py | 38 ++++++ .../main/protobuf/spark/connect/types.proto | 18 ++- 9 files changed, 314 insertions(+), 59 deletions(-) create mode 100644 python/pyspark/sql/tests/connect/test_parity_geographytype.py create mode 100644 python/pyspark/sql/tests/connect/test_parity_geometrytype.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 07ac4c76b91a6..aa8ca58a5a75f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1114,6 +1114,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_connect_retry", "pyspark.sql.tests.connect.test_connect_session", "pyspark.sql.tests.connect.test_connect_stat", + "pyspark.sql.tests.connect.test_parity_geographytype", + "pyspark.sql.tests.connect.test_parity_geometrytype", "pyspark.sql.tests.connect.test_parity_datasources", "pyspark.sql.tests.connect.test_parity_errors", "pyspark.sql.tests.connect.test_parity_catalog", diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index d169e6293a1ba..51bbdd8625164 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -549,6 +549,16 @@ " and should be of the same length, got and ." ] }, + "MALFORMED_GEOGRAPHY": { + "message": [ + "Geography binary is malformed. Please check the data source is valid." + ] + }, + "MALFORMED_GEOMETRY": { + "message": [ + "Geometry binary is malformed. Please check the data source is valid." + ] + }, "MALFORMED_VARIANT": { "message": [ "Variant binary is malformed. Please check the data source is valid." diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py index 9a52129103ad5..4e35f6b8911a2 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.py +++ b/python/pyspark/sql/connect/proto/types_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xac#\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0b\x32\x1d.spark.connect.DataType.ShortH\x00R\x05short\x12;\n\x07integer\x18\x06 \x01(\x0b\x32\x1f.spark.connect.DataType.IntegerH\x00R\x07integer\x12\x32\n\x04long\x18\x07 \x01(\x0b\x32\x1c.spark.connect.DataType.LongH\x00R\x04long\x12\x35\n\x05\x66loat\x18\x08 \x01(\x0b\x32\x1d.spark.connect.DataType.FloatH\x00R\x05\x66loat\x12\x38\n\x06\x64ouble\x18\t \x01(\x0b\x32\x1e.spark.connect.DataType.DoubleH\x00R\x06\x64ouble\x12;\n\x07\x64\x65\x63imal\x18\n \x01(\x0b\x32\x1f.spark.connect.DataType.DecimalH\x00R\x07\x64\x65\x63imal\x12\x38\n\x06string\x18\x0b \x01(\x0b\x32\x1e.spark.connect.DataType.StringH\x00R\x06string\x12\x32\n\x04\x63har\x18\x0c \x01(\x0b\x32\x1c.spark.connect.DataType.CharH\x00R\x04\x63har\x12<\n\x08var_char\x18\r \x01(\x0b\x32\x1f.spark.connect.DataType.VarCharH\x00R\x07varChar\x12\x32\n\x04\x64\x61te\x18\x0e \x01(\x0b\x32\x1c.spark.connect.DataType.DateH\x00R\x04\x64\x61te\x12\x41\n\ttimestamp\x18\x0f \x01(\x0b\x32!.spark.connect.DataType.TimestampH\x00R\ttimestamp\x12K\n\rtimestamp_ntz\x18\x10 \x01(\x0b\x32$.spark.connect.DataType.TimestampNTZH\x00R\x0ctimestampNtz\x12W\n\x11\x63\x61lendar_interval\x18\x11 \x01(\x0b\x32(.spark.connect.DataType.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12[\n\x13year_month_interval\x18\x12 \x01(\x0b\x32).spark.connect.DataType.YearMonthIntervalH\x00R\x11yearMonthInterval\x12U\n\x11\x64\x61y_time_interval\x18\x13 \x01(\x0b\x32'.spark.connect.DataType.DayTimeIntervalH\x00R\x0f\x64\x61yTimeInterval\x12\x35\n\x05\x61rray\x18\x14 \x01(\x0b\x32\x1d.spark.connect.DataType.ArrayH\x00R\x05\x61rray\x12\x38\n\x06struct\x18\x15 \x01(\x0b\x32\x1e.spark.connect.DataType.StructH\x00R\x06struct\x12/\n\x03map\x18\x16 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x03map\x12;\n\x07variant\x18\x19 \x01(\x0b\x32\x1f.spark.connect.DataType.VariantH\x00R\x07variant\x12/\n\x03udt\x18\x17 \x01(\x0b\x32\x1b.spark.connect.DataType.UDTH\x00R\x03udt\x12>\n\x08unparsed\x18\x18 \x01(\x0b\x32 .spark.connect.DataType.UnparsedH\x00R\x08unparsed\x12\x32\n\x04time\x18\x1c \x01(\x0b\x32\x1c.spark.connect.DataType.TimeH\x00R\x04time\x1a\x43\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x42yte\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05Short\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Integer\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04Long\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05\x46loat\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x42\n\x06\x44ouble\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a`\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x1c\n\tcollation\x18\x02 \x01(\tR\tcollation\x1a\x42\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04NULL\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x45\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aH\n\x0cTimestampNTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aq\n\x04Time\x12!\n\tprecision\x18\x01 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReferenceB\x0c\n\n_precision\x1aL\n\x10\x43\x61lendarInterval\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xb3\x01\n\x11YearMonthInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1a\xb1\x01\n\x0f\x44\x61yTimeInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1aX\n\x04\x43har\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a[\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x99\x01\n\x07\x44\x65\x63imal\x12\x19\n\x05scale\x18\x01 \x01(\x05H\x00R\x05scale\x88\x01\x01\x12!\n\tprecision\x18\x02 \x01(\x05H\x01R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x08\n\x06_scaleB\x0c\n\n_precision\x1a\xa1\x01\n\x0bStructField\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x34\n\tdata_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x12\x1a\n\x08nullable\x18\x03 \x01(\x08R\x08nullable\x12\x1f\n\x08metadata\x18\x04 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x7f\n\x06Struct\x12;\n\x06\x66ields\x18\x01 \x03(\x0b\x32#.spark.connect.DataType.StructFieldR\x06\x66ields\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\xa2\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12#\n\rcontains_null\x18\x02 \x01(\x08R\x0c\x63ontainsNull\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x1a\xdb\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12.\n\x13value_contains_null\x18\x03 \x01(\x08R\x11valueContainsNull\x12\x38\n\x18type_variation_reference\x18\x04 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Variant\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xa1\x02\n\x03UDT\x12\x12\n\x04type\x18\x01 \x01(\tR\x04type\x12 \n\tjvm_class\x18\x02 \x01(\tH\x00R\x08jvmClass\x88\x01\x01\x12&\n\x0cpython_class\x18\x03 \x01(\tH\x01R\x0bpythonClass\x88\x01\x01\x12;\n\x17serialized_python_class\x18\x04 \x01(\tH\x02R\x15serializedPythonClass\x88\x01\x01\x12\x37\n\x08sql_type\x18\x05 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x07sqlType\x88\x01\x01\x42\x0c\n\n_jvm_classB\x0f\n\r_python_classB\x1a\n\x18_serialized_python_classB\x0b\n\t_sql_type\x1a\x34\n\x08Unparsed\x12(\n\x10\x64\x61ta_type_string\x18\x01 \x01(\tR\x0e\x64\x61taTypeStringB\x06\n\x04kindJ\x04\x08\x1a\x10\x1bJ\x04\x08\x1b\x10\x1c\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3" + b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xd8%\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0b\x32\x1d.spark.connect.DataType.ShortH\x00R\x05short\x12;\n\x07integer\x18\x06 \x01(\x0b\x32\x1f.spark.connect.DataType.IntegerH\x00R\x07integer\x12\x32\n\x04long\x18\x07 \x01(\x0b\x32\x1c.spark.connect.DataType.LongH\x00R\x04long\x12\x35\n\x05\x66loat\x18\x08 \x01(\x0b\x32\x1d.spark.connect.DataType.FloatH\x00R\x05\x66loat\x12\x38\n\x06\x64ouble\x18\t \x01(\x0b\x32\x1e.spark.connect.DataType.DoubleH\x00R\x06\x64ouble\x12;\n\x07\x64\x65\x63imal\x18\n \x01(\x0b\x32\x1f.spark.connect.DataType.DecimalH\x00R\x07\x64\x65\x63imal\x12\x38\n\x06string\x18\x0b \x01(\x0b\x32\x1e.spark.connect.DataType.StringH\x00R\x06string\x12\x32\n\x04\x63har\x18\x0c \x01(\x0b\x32\x1c.spark.connect.DataType.CharH\x00R\x04\x63har\x12<\n\x08var_char\x18\r \x01(\x0b\x32\x1f.spark.connect.DataType.VarCharH\x00R\x07varChar\x12\x32\n\x04\x64\x61te\x18\x0e \x01(\x0b\x32\x1c.spark.connect.DataType.DateH\x00R\x04\x64\x61te\x12\x41\n\ttimestamp\x18\x0f \x01(\x0b\x32!.spark.connect.DataType.TimestampH\x00R\ttimestamp\x12K\n\rtimestamp_ntz\x18\x10 \x01(\x0b\x32$.spark.connect.DataType.TimestampNTZH\x00R\x0ctimestampNtz\x12W\n\x11\x63\x61lendar_interval\x18\x11 \x01(\x0b\x32(.spark.connect.DataType.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12[\n\x13year_month_interval\x18\x12 \x01(\x0b\x32).spark.connect.DataType.YearMonthIntervalH\x00R\x11yearMonthInterval\x12U\n\x11\x64\x61y_time_interval\x18\x13 \x01(\x0b\x32'.spark.connect.DataType.DayTimeIntervalH\x00R\x0f\x64\x61yTimeInterval\x12\x35\n\x05\x61rray\x18\x14 \x01(\x0b\x32\x1d.spark.connect.DataType.ArrayH\x00R\x05\x61rray\x12\x38\n\x06struct\x18\x15 \x01(\x0b\x32\x1e.spark.connect.DataType.StructH\x00R\x06struct\x12/\n\x03map\x18\x16 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x03map\x12;\n\x07variant\x18\x19 \x01(\x0b\x32\x1f.spark.connect.DataType.VariantH\x00R\x07variant\x12/\n\x03udt\x18\x17 \x01(\x0b\x32\x1b.spark.connect.DataType.UDTH\x00R\x03udt\x12>\n\x08geometry\x18\x1a \x01(\x0b\x32 .spark.connect.DataType.GeometryH\x00R\x08geometry\x12\x41\n\tgeography\x18\x1b \x01(\x0b\x32!.spark.connect.DataType.GeographyH\x00R\tgeography\x12>\n\x08unparsed\x18\x18 \x01(\x0b\x32 .spark.connect.DataType.UnparsedH\x00R\x08unparsed\x12\x32\n\x04time\x18\x1c \x01(\x0b\x32\x1c.spark.connect.DataType.TimeH\x00R\x04time\x1a\x43\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x42yte\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05Short\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Integer\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04Long\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05\x46loat\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x42\n\x06\x44ouble\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a`\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x1c\n\tcollation\x18\x02 \x01(\tR\tcollation\x1a\x42\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04NULL\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x45\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aH\n\x0cTimestampNTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aq\n\x04Time\x12!\n\tprecision\x18\x01 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReferenceB\x0c\n\n_precision\x1aL\n\x10\x43\x61lendarInterval\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xb3\x01\n\x11YearMonthInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1a\xb1\x01\n\x0f\x44\x61yTimeInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1aX\n\x04\x43har\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a[\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x99\x01\n\x07\x44\x65\x63imal\x12\x19\n\x05scale\x18\x01 \x01(\x05H\x00R\x05scale\x88\x01\x01\x12!\n\tprecision\x18\x02 \x01(\x05H\x01R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x08\n\x06_scaleB\x0c\n\n_precision\x1a\xa1\x01\n\x0bStructField\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x34\n\tdata_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x12\x1a\n\x08nullable\x18\x03 \x01(\x08R\x08nullable\x12\x1f\n\x08metadata\x18\x04 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x7f\n\x06Struct\x12;\n\x06\x66ields\x18\x01 \x03(\x0b\x32#.spark.connect.DataType.StructFieldR\x06\x66ields\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\xa2\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12#\n\rcontains_null\x18\x02 \x01(\x08R\x0c\x63ontainsNull\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x1a\xdb\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12.\n\x13value_contains_null\x18\x03 \x01(\x08R\x11valueContainsNull\x12\x38\n\x18type_variation_reference\x18\x04 \x01(\rR\x16typeVariationReference\x1aX\n\x08Geometry\x12\x12\n\x04srid\x18\x01 \x01(\x05R\x04srid\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1aY\n\tGeography\x12\x12\n\x04srid\x18\x01 \x01(\x05R\x04srid\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Variant\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xa1\x02\n\x03UDT\x12\x12\n\x04type\x18\x01 \x01(\tR\x04type\x12 \n\tjvm_class\x18\x02 \x01(\tH\x00R\x08jvmClass\x88\x01\x01\x12&\n\x0cpython_class\x18\x03 \x01(\tH\x01R\x0bpythonClass\x88\x01\x01\x12;\n\x17serialized_python_class\x18\x04 \x01(\tH\x02R\x15serializedPythonClass\x88\x01\x01\x12\x37\n\x08sql_type\x18\x05 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x07sqlType\x88\x01\x01\x42\x0c\n\n_jvm_classB\x0f\n\r_python_classB\x1a\n\x18_serialized_python_classB\x0b\n\t_sql_type\x1a\x34\n\x08Unparsed\x12(\n\x10\x64\x61ta_type_string\x18\x01 \x01(\tR\x0e\x64\x61taTypeStringB\x06\n\x04kindB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3" ) _globals = globals() @@ -47,59 +47,63 @@ "DESCRIPTOR" ]._serialized_options = b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated" _globals["_DATATYPE"]._serialized_start = 45 - _globals["_DATATYPE"]._serialized_end = 4569 - _globals["_DATATYPE_BOOLEAN"]._serialized_start = 1647 - _globals["_DATATYPE_BOOLEAN"]._serialized_end = 1714 - _globals["_DATATYPE_BYTE"]._serialized_start = 1716 - _globals["_DATATYPE_BYTE"]._serialized_end = 1780 - _globals["_DATATYPE_SHORT"]._serialized_start = 1782 - _globals["_DATATYPE_SHORT"]._serialized_end = 1847 - _globals["_DATATYPE_INTEGER"]._serialized_start = 1849 - _globals["_DATATYPE_INTEGER"]._serialized_end = 1916 - _globals["_DATATYPE_LONG"]._serialized_start = 1918 - _globals["_DATATYPE_LONG"]._serialized_end = 1982 - _globals["_DATATYPE_FLOAT"]._serialized_start = 1984 - _globals["_DATATYPE_FLOAT"]._serialized_end = 2049 - _globals["_DATATYPE_DOUBLE"]._serialized_start = 2051 - _globals["_DATATYPE_DOUBLE"]._serialized_end = 2117 - _globals["_DATATYPE_STRING"]._serialized_start = 2119 - _globals["_DATATYPE_STRING"]._serialized_end = 2215 - _globals["_DATATYPE_BINARY"]._serialized_start = 2217 - _globals["_DATATYPE_BINARY"]._serialized_end = 2283 - _globals["_DATATYPE_NULL"]._serialized_start = 2285 - _globals["_DATATYPE_NULL"]._serialized_end = 2349 - _globals["_DATATYPE_TIMESTAMP"]._serialized_start = 2351 - _globals["_DATATYPE_TIMESTAMP"]._serialized_end = 2420 - _globals["_DATATYPE_DATE"]._serialized_start = 2422 - _globals["_DATATYPE_DATE"]._serialized_end = 2486 - _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_start = 2488 - _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_end = 2560 - _globals["_DATATYPE_TIME"]._serialized_start = 2562 - _globals["_DATATYPE_TIME"]._serialized_end = 2675 - _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_start = 2677 - _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_end = 2753 - _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_start = 2756 - _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_end = 2935 - _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_start = 2938 - _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_end = 3115 - _globals["_DATATYPE_CHAR"]._serialized_start = 3117 - _globals["_DATATYPE_CHAR"]._serialized_end = 3205 - _globals["_DATATYPE_VARCHAR"]._serialized_start = 3207 - _globals["_DATATYPE_VARCHAR"]._serialized_end = 3298 - _globals["_DATATYPE_DECIMAL"]._serialized_start = 3301 - _globals["_DATATYPE_DECIMAL"]._serialized_end = 3454 - _globals["_DATATYPE_STRUCTFIELD"]._serialized_start = 3457 - _globals["_DATATYPE_STRUCTFIELD"]._serialized_end = 3618 - _globals["_DATATYPE_STRUCT"]._serialized_start = 3620 - _globals["_DATATYPE_STRUCT"]._serialized_end = 3747 - _globals["_DATATYPE_ARRAY"]._serialized_start = 3750 - _globals["_DATATYPE_ARRAY"]._serialized_end = 3912 - _globals["_DATATYPE_MAP"]._serialized_start = 3915 - _globals["_DATATYPE_MAP"]._serialized_end = 4134 - _globals["_DATATYPE_VARIANT"]._serialized_start = 4136 - _globals["_DATATYPE_VARIANT"]._serialized_end = 4203 - _globals["_DATATYPE_UDT"]._serialized_start = 4206 - _globals["_DATATYPE_UDT"]._serialized_end = 4495 - _globals["_DATATYPE_UNPARSED"]._serialized_start = 4497 - _globals["_DATATYPE_UNPARSED"]._serialized_end = 4549 + _globals["_DATATYPE"]._serialized_end = 4869 + _globals["_DATATYPE_BOOLEAN"]._serialized_start = 1778 + _globals["_DATATYPE_BOOLEAN"]._serialized_end = 1845 + _globals["_DATATYPE_BYTE"]._serialized_start = 1847 + _globals["_DATATYPE_BYTE"]._serialized_end = 1911 + _globals["_DATATYPE_SHORT"]._serialized_start = 1913 + _globals["_DATATYPE_SHORT"]._serialized_end = 1978 + _globals["_DATATYPE_INTEGER"]._serialized_start = 1980 + _globals["_DATATYPE_INTEGER"]._serialized_end = 2047 + _globals["_DATATYPE_LONG"]._serialized_start = 2049 + _globals["_DATATYPE_LONG"]._serialized_end = 2113 + _globals["_DATATYPE_FLOAT"]._serialized_start = 2115 + _globals["_DATATYPE_FLOAT"]._serialized_end = 2180 + _globals["_DATATYPE_DOUBLE"]._serialized_start = 2182 + _globals["_DATATYPE_DOUBLE"]._serialized_end = 2248 + _globals["_DATATYPE_STRING"]._serialized_start = 2250 + _globals["_DATATYPE_STRING"]._serialized_end = 2346 + _globals["_DATATYPE_BINARY"]._serialized_start = 2348 + _globals["_DATATYPE_BINARY"]._serialized_end = 2414 + _globals["_DATATYPE_NULL"]._serialized_start = 2416 + _globals["_DATATYPE_NULL"]._serialized_end = 2480 + _globals["_DATATYPE_TIMESTAMP"]._serialized_start = 2482 + _globals["_DATATYPE_TIMESTAMP"]._serialized_end = 2551 + _globals["_DATATYPE_DATE"]._serialized_start = 2553 + _globals["_DATATYPE_DATE"]._serialized_end = 2617 + _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_start = 2619 + _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_end = 2691 + _globals["_DATATYPE_TIME"]._serialized_start = 2693 + _globals["_DATATYPE_TIME"]._serialized_end = 2806 + _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_start = 2808 + _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_end = 2884 + _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_start = 2887 + _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_end = 3066 + _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_start = 3069 + _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_end = 3246 + _globals["_DATATYPE_CHAR"]._serialized_start = 3248 + _globals["_DATATYPE_CHAR"]._serialized_end = 3336 + _globals["_DATATYPE_VARCHAR"]._serialized_start = 3338 + _globals["_DATATYPE_VARCHAR"]._serialized_end = 3429 + _globals["_DATATYPE_DECIMAL"]._serialized_start = 3432 + _globals["_DATATYPE_DECIMAL"]._serialized_end = 3585 + _globals["_DATATYPE_STRUCTFIELD"]._serialized_start = 3588 + _globals["_DATATYPE_STRUCTFIELD"]._serialized_end = 3749 + _globals["_DATATYPE_STRUCT"]._serialized_start = 3751 + _globals["_DATATYPE_STRUCT"]._serialized_end = 3878 + _globals["_DATATYPE_ARRAY"]._serialized_start = 3881 + _globals["_DATATYPE_ARRAY"]._serialized_end = 4043 + _globals["_DATATYPE_MAP"]._serialized_start = 4046 + _globals["_DATATYPE_MAP"]._serialized_end = 4265 + _globals["_DATATYPE_GEOMETRY"]._serialized_start = 4267 + _globals["_DATATYPE_GEOMETRY"]._serialized_end = 4355 + _globals["_DATATYPE_GEOGRAPHY"]._serialized_start = 4357 + _globals["_DATATYPE_GEOGRAPHY"]._serialized_end = 4446 + _globals["_DATATYPE_VARIANT"]._serialized_start = 4448 + _globals["_DATATYPE_VARIANT"]._serialized_end = 4515 + _globals["_DATATYPE_UDT"]._serialized_start = 4518 + _globals["_DATATYPE_UDT"]._serialized_end = 4807 + _globals["_DATATYPE_UNPARSED"]._serialized_start = 4809 + _globals["_DATATYPE_UNPARSED"]._serialized_end = 4861 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi index d46770c4f888e..3f625890a809b 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.pyi +++ b/python/pyspark/sql/connect/proto/types_pb2.pyi @@ -674,6 +674,46 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... + class Geometry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SRID_FIELD_NUMBER: builtins.int + TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int + srid: builtins.int + type_variation_reference: builtins.int + def __init__( + self, + *, + srid: builtins.int = ..., + type_variation_reference: builtins.int = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "srid", b"srid", "type_variation_reference", b"type_variation_reference" + ], + ) -> None: ... + + class Geography(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SRID_FIELD_NUMBER: builtins.int + TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int + srid: builtins.int + type_variation_reference: builtins.int + def __init__( + self, + *, + srid: builtins.int = ..., + type_variation_reference: builtins.int = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "srid", b"srid", "type_variation_reference", b"type_variation_reference" + ], + ) -> None: ... + class Variant(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -821,6 +861,8 @@ class DataType(google.protobuf.message.Message): MAP_FIELD_NUMBER: builtins.int VARIANT_FIELD_NUMBER: builtins.int UDT_FIELD_NUMBER: builtins.int + GEOMETRY_FIELD_NUMBER: builtins.int + GEOGRAPHY_FIELD_NUMBER: builtins.int UNPARSED_FIELD_NUMBER: builtins.int TIME_FIELD_NUMBER: builtins.int @property @@ -878,6 +920,11 @@ class DataType(google.protobuf.message.Message): def udt(self) -> global___DataType.UDT: """UserDefinedType""" @property + def geometry(self) -> global___DataType.Geometry: + """Geospatial types""" + @property + def geography(self) -> global___DataType.Geography: ... + @property def unparsed(self) -> global___DataType.Unparsed: """UnparsedDataType""" @property @@ -909,6 +956,8 @@ class DataType(google.protobuf.message.Message): map: global___DataType.Map | None = ..., variant: global___DataType.Variant | None = ..., udt: global___DataType.UDT | None = ..., + geometry: global___DataType.Geometry | None = ..., + geography: global___DataType.Geography | None = ..., unparsed: global___DataType.Unparsed | None = ..., time: global___DataType.Time | None = ..., ) -> None: ... @@ -937,6 +986,10 @@ class DataType(google.protobuf.message.Message): b"double", "float", b"float", + "geography", + b"geography", + "geometry", + b"geometry", "integer", b"integer", "kind", @@ -996,6 +1049,10 @@ class DataType(google.protobuf.message.Message): b"double", "float", b"float", + "geography", + b"geography", + "geometry", + b"geometry", "integer", b"integer", "kind", @@ -1058,6 +1115,8 @@ class DataType(google.protobuf.message.Message): "map", "variant", "udt", + "geometry", + "geography", "unparsed", "time", ] diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 7e8f768610794..d3352b618d7c7 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -50,6 +50,8 @@ NullType, NumericType, VariantType, + GeographyType, + GeometryType, UserDefinedType, ) from pyspark.errors import PySparkAssertionError, PySparkValueError @@ -191,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: ret.array.contains_null = data_type.containsNull elif isinstance(data_type, VariantType): ret.variant.CopyFrom(pb2.DataType.Variant()) + elif isinstance(data_type, GeometryType): + ret.geometry.srid = data_type.srid + elif isinstance(data_type, GeographyType): + ret.geography.srid = data_type.srid elif isinstance(data_type, UserDefinedType): json_value = data_type.jsonValue() ret.udt.type = "udt" @@ -303,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: ) elif schema.HasField("variant"): return VariantType() + elif schema.HasField("geometry"): + srid = schema.geometry.srid + if srid == GeometryType.MIXED_SRID: + return GeometryType("ANY") + else: + return GeometryType(srid) + elif schema.HasField("geography"): + srid = schema.geography.srid + if srid == GeographyType.MIXED_SRID: + return GeographyType("ANY") + else: + return GeographyType(srid) elif schema.HasField("udt"): assert schema.udt.type == "udt" json_value = {} diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index a8f621277a0af..f73727d1d5344 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -28,6 +28,10 @@ BinaryType, DataType, DecimalType, + GeographyType, + Geography, + GeometryType, + Geometry, MapType, NullType, Row, @@ -89,6 +93,10 @@ def _need_converter( return True elif isinstance(dataType, VariantType): return True + elif isinstance(dataType, GeometryType): + return True + elif isinstance(dataType, GeographyType): + return True else: return False @@ -392,6 +400,34 @@ def convert_variant(value: Any) -> Any: return convert_variant + elif isinstance(dataType, GeographyType): + + def convert_geography(value: Any) -> Any: + if value is None: + if not nullable: + raise PySparkValueError(f"input for {dataType} must not be None") + return None + elif isinstance(value, Geography): + return dataType.toInternal(value) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY") + + return convert_geography + + elif isinstance(dataType, GeometryType): + + def convert_geometry(value: Any) -> Any: + if value is None: + if not nullable: + raise PySparkValueError(f"input for {dataType} must not be None") + return None + elif isinstance(value, Geometry): + return dataType.toInternal(value) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOMETRY") + + return convert_geometry + elif not nullable: def convert_other(value: Any) -> Any: @@ -511,6 +547,10 @@ def _need_converter(dataType: DataType) -> bool: return True elif isinstance(dataType, VariantType): return True + elif isinstance(dataType, GeographyType): + return True + elif isinstance(dataType, GeometryType): + return True else: return False @@ -719,6 +759,40 @@ def convert_variant(value: Any) -> Any: return convert_variant + elif isinstance(dataType, GeographyType): + + def convert_geography(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geography.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY") + + return convert_geography + + elif isinstance(dataType, GeometryType): + + def convert_geometry(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geometry.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOMETRY") + + return convert_geometry + else: if none_on_identity: return None diff --git a/python/pyspark/sql/tests/connect/test_parity_geographytype.py b/python/pyspark/sql/tests/connect/test_parity_geographytype.py new file mode 100644 index 0000000000000..501bbed20ff19 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_geographytype.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_geographytype import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_geometrytype.py b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py new file mode 100644 index 0000000000000..b95321b3c61be --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_geometrytype import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/sql/connect/common/src/main/protobuf/spark/connect/types.proto b/sql/connect/common/src/main/protobuf/spark/connect/types.proto index 1800e3885774f..caaa2340f95dd 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/types.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/types.proto @@ -67,15 +67,17 @@ message DataType { // UserDefinedType UDT udt = 23; + // Geospatial types + Geometry geometry = 26; + + Geography geography = 27; + // UnparsedDataType Unparsed unparsed = 24; Time time = 28; } - // Reserved for geometry and geography types - reserved 26, 27; - message Boolean { uint32 type_variation_reference = 1; } @@ -192,6 +194,16 @@ message DataType { uint32 type_variation_reference = 4; } + message Geometry { + int32 srid = 1; + uint32 type_variation_reference = 2; + } + + message Geography { + int32 srid = 1; + uint32 type_variation_reference = 2; + } + message Variant { uint32 type_variation_reference = 1; } From 60bd30fb76f5a09b71ff82ef9fe96c16351d0b70 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 4 Nov 2025 17:40:17 +0100 Subject: [PATCH 2/5] Add Geography and Geometry classes --- python/pyspark/sql/types.py | 172 ++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 440100dba9312..a4d11d3ff063d 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -90,6 +90,8 @@ "TimestampNTZType", "DecimalType", "DoubleType", + "Geography", + "Geometry", "FloatType", "ByteType", "IntegerType", @@ -616,6 +618,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]: # The JSON representation always uses the CRS and algorithm value. return f"geography({self._crs}, {self._alg})" + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Dict) -> Optional["Geography"]: + if obj is None or not all(key in obj for key in ["srid", "bytes"]): + return None + return Geography(obj["bytes"], obj["srid"]) + + def toInternal(self, geography: Any) -> Any: + if geography is None: + return None + assert isinstance(geography, Geography) + return {"srid": geography.srid, "wkb": geography.wkb} + class GeometryType(SpatialType): """ @@ -700,6 +716,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]: # The JSON representation always uses the CRS value. return f"geometry({self._crs})" + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Dict) -> Optional["Geometry"]: + if obj is None or not all(key in obj for key in ["srid", "bytes"]): + return None + return Geometry(obj["bytes"], obj["srid"]) + + def toInternal(self, geometry: Any) -> Any: + if geometry is None: + return None + assert isinstance(geometry, Geometry) + return {"srid": geometry.srid, "wkb": geometry.wkb} + class ByteType(IntegralType): """Byte data type, representing signed 8-bit integers.""" @@ -2039,6 +2069,148 @@ def parseJson(cls, json_str: str) -> "VariantVal": return VariantVal(value, metadata) +class Geography: + """ + A class to represent a Geography value in Python. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + wkb : bytes + The bytes representing the WKB of Geography. + + srid : integer + The integer value representing SRID of Geography. + + Methods + ------- + getBytes() + Returns the WKB of Geography. + + getSrid() + Returns the SRID of Geography. + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001C40"),)], ["wkb"],) # noqa + >>> g = df.select(sf.geogfromwkb(df.wkb).alias("geog")).head().geog # doctest: +SKIP + >>> g.getBytes() # doctest: +SKIP + b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x001@\x00\x00\x00\x00\x00\x00\x1c@' + >>> g.getSrid() # doctest: +SKIP + 4326 + """ + + def __init__(self, wkb: bytes, srid: int): + self.wkb = wkb + self.srid = srid + + def __str__(self) -> str: + return "Geography(%r, %d)" % (self.wkb, self.srid) + + def __repr__(self) -> str: + return "Geography(%r, %d)" % (self.wkb, self.srid) + + def getSrid(self) -> int: + """ + Returns the SRID of Geography. + """ + return self.srid + + def getBytes(self) -> bytes: + """ + Returns the WKB of Geography. + """ + return self.wkb + + def __eq__(self, other): + if not isinstance(other, Geography): + # Don't attempt to compare against unrelated types. + return NotImplemented + + return self.wkb == other.wkb and self.srid == other.srid + + @classmethod + def fromWKB(cls, wkb: bytes, srid: int) -> "Geography": + """ + Construct Python Geography object from WKB. + :return: Python representation of the Geography type value. + """ + return Geography(wkb, srid) + + +class Geometry: + """ + A class to represent a Geometry value in Python. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + wkb : bytes + The bytes representing the WKB of Geometry. + + srid : integer + The integer value representing SRID of Geometry. + + Methods + ------- + getBytes() + Returns the WKB of Geometry. + + getSrid() + Returns the SRID of Geometry. + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001C40"),)], ["wkb"],) # noqa + >>> g = df.select(sf.geomfromwkb(df.geomwkb).alias("geom")).head().geom # doctest: +SKIP + >>> g.getBytes() # doctest: +SKIP + b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x001@\x00\x00\x00\x00\x00\x00\x1c@' + >>> g.getSrid() # doctest: +SKIP + 0 + """ + + def __init__(self, wkb: bytes, srid: int): + self.wkb = wkb + self.srid = srid + + def __str__(self) -> str: + return "Geometry(%r, %d)" % (self.wkb, self.srid) + + def __repr__(self) -> str: + return "Geometry(%r, %d)" % (self.wkb, self.srid) + + def getSrid(self) -> int: + """ + Returns the SRID of Geometry. + """ + return self.srid + + def getBytes(self) -> bytes: + """ + Returns the WKB of Geometry. + """ + return self.wkb + + def __eq__(self, other): + if not isinstance(other, Geometry): + # Don't attempt to compare against unrelated types. + return NotImplemented + + return self.wkb == other.wkb and self.srid == other.srid + + @classmethod + def fromWKB(cls, wkb: bytes, srid: int) -> "Geometry": + """ + Construct Python Geometry object from WKB. + :return: Python representation of the Geometry type value. + """ + return Geometry(wkb, srid) + + _atomic_types: List[Type[DataType]] = [ StringType, CharType, From db24ae60e2f354f40e288bd21190c4622f4212bd Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 5 Nov 2025 08:16:58 +0100 Subject: [PATCH 3/5] Fix mypy issues --- python/pyspark/sql/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index a4d11d3ff063d..93fe1a74b1292 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2124,7 +2124,7 @@ def getBytes(self) -> bytes: """ return self.wkb - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, Geography): # Don't attempt to compare against unrelated types. return NotImplemented @@ -2195,7 +2195,7 @@ def getBytes(self) -> bytes: """ return self.wkb - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, Geometry): # Don't attempt to compare against unrelated types. return NotImplemented From 1ac1522a6856480218b2e559447ee815c4cba4b2 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 5 Nov 2025 11:30:58 +0100 Subject: [PATCH 4/5] Fix doctest examples --- python/pyspark/sql/types.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 93fe1a74b1292..8e820f9d3f89c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2094,11 +2094,11 @@ class Geography: Examples -------- >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001C40"),)], ["wkb"],) # noqa - >>> g = df.select(sf.geogfromwkb(df.wkb).alias("geog")).head().geog # doctest: +SKIP - >>> g.getBytes() # doctest: +SKIP - b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x001@\x00\x00\x00\x00\x00\x00\x1c@' - >>> g.getSrid() # doctest: +SKIP + >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001c40"),)], ["wkb"],) # noqa + >>> g = df.select(sf.st_geogfromwkb(df.wkb).alias("geog")).head().geog + >>> g.getBytes().hex() + '010100000000000000000031400000000000001c40' + >>> g.getSrid() 4326 """ @@ -2165,11 +2165,11 @@ class Geometry: Examples -------- >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001C40"),)], ["wkb"],) # noqa - >>> g = df.select(sf.geomfromwkb(df.geomwkb).alias("geom")).head().geom # doctest: +SKIP - >>> g.getBytes() # doctest: +SKIP - b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x001@\x00\x00\x00\x00\x00\x00\x1c@' - >>> g.getSrid() # doctest: +SKIP + >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001c40"),)], ["wkb"],) # noqa + >>> g = df.select(sf.st_geomfromwkb(df.wkb).alias("geom")).head().geom + >>> g.getBytes().hex() + '010100000000000000000031400000000000001c40' + >>> g.getSrid() 0 """ From 30884ba52d8fc2a46567cbda5bf69e6204ee0ccf Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 5 Nov 2025 16:53:26 +0100 Subject: [PATCH 5/5] Simplify doctest examples --- python/pyspark/sql/types.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8e820f9d3f89c..8aae398800727 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2093,9 +2093,7 @@ class Geography: Examples -------- - >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001c40"),)], ["wkb"],) # noqa - >>> g = df.select(sf.st_geogfromwkb(df.wkb).alias("geog")).head().geog + >>> g = Geography.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 4326) >>> g.getBytes().hex() '010100000000000000000031400000000000001c40' >>> g.getSrid() @@ -2164,9 +2162,7 @@ class Geometry: Examples -------- - >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001c40"),)], ["wkb"],) # noqa - >>> g = df.select(sf.st_geomfromwkb(df.wkb).alias("geom")).head().geom + >>> g = Geometry.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 0) >>> g.getBytes().hex() '010100000000000000000031400000000000001c40' >>> g.getSrid()