diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index ec4490e845ae..6c0facbfeee8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -43,7 +43,11 @@ message Expression {
     Expression expr = 1;
 
     // (Required) the data type that the expr to be casted to.
-    DataType cast_to_type = 2;
+    oneof cast_to_type {
+      DataType type = 2;
+      // If set, the server will use the Catalyst parser to parse this string into a DataType.
+      string type_str = 3;
+    }
   }
 
   message Literal {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 44baf4078164..fb79243ba379 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -96,7 +96,17 @@ package object dsl {
           Expression.Cast
             .newBuilder()
             .setExpr(expr)
-            .setCastToType(dataType))
+            .setType(dataType))
+        .build()
+
+    def cast(dataType: String): Expression =
+      Expression
+        .newBuilder()
+        .setCast(
+          Expression.Cast
+            .newBuilder()
+            .setExpr(expr)
+            .setTypeStr(dataType))
         .build()
   }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index af5d9abc5154..55283ca96b13 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -518,9 +518,16 @@ class SparkConnectPlanner(session: SparkSession) {
   }
 
   private def transformCast(cast: proto.Expression.Cast): Expression = {
-    Cast(
-      transformExpression(cast.getExpr),
-      DataTypeProtoConverter.toCatalystType(cast.getCastToType))
+    cast.getCastToTypeCase match {
+      case proto.Expression.Cast.CastToTypeCase.TYPE =>
+        Cast(
+          transformExpression(cast.getExpr),
+          DataTypeProtoConverter.toCatalystType(cast.getType))
+      case _ =>
+        Cast(
+          transformExpression(cast.getExpr),
+          session.sessionState.sqlParser.parseDataType(cast.getTypeStr))
+    }
   }
 
   private def transformSetOperation(u: proto.SetOperation): LogicalPlan = {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 6d36ea9a6305..c04d7bde7461 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -550,6 +550,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       connectTestRelation.select("id".protoAttr.cast(
         proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build())),
       sparkTestRelation.select(col("id").cast(StringType)))
+
+    comparePlans(
+      connectTestRelation.select("id".protoAttr.cast("string")),
+      sparkTestRelation.select(col("id").cast("string")))
   }
 
   test("Test Hint") {
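The proto change above turns `cast_to_type` into a oneof, so a client either ships a structured `DataType` or defers parsing of a DDL string to the server, which `transformCast` now routes through `sqlParser.parseDataType`. A minimal client-side sketch of the two branches (the column name "id" is illustrative; the proto module path is the one used later in this diff):

    import pyspark.sql.connect.proto as pb2

    cast_expr = pb2.Expression()
    cast_expr.cast.expr.unresolved_attribute.unparsed_identifier = "id"

    # Branch 1: a structured DataType message.
    cast_expr.cast.type.CopyFrom(pb2.DataType(string=pb2.DataType.String()))

    # Branch 2: a DDL string. Assigning it clears `type`; the server will
    # hand the string to the Catalyst parser instead.
    cast_expr.cast.type_str = "string"
    assert cast_expr.cast.WhichOneof("cast_to_type") == "type_str"
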
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 745ca79fda91..790f9980d5df 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -26,30 +26,13 @@
 
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
+import pyspark.sql.connect.types as types
 import pyspark.sql.types
 from pyspark import cloudpickle
 from pyspark.sql.types import (
     DataType,
-    ByteType,
-    ShortType,
-    IntegerType,
-    FloatType,
-    DateType,
-    TimestampType,
-    DayTimeIntervalType,
-    MapType,
-    StringType,
-    CharType,
-    VarcharType,
     StructType,
     StructField,
-    ArrayType,
-    DoubleType,
-    LongType,
-    DecimalType,
-    BinaryType,
-    BooleanType,
-    NullType,
 )
 
 
@@ -350,73 +333,7 @@ def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
         return self._execute_and_fetch(req)
 
     def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
-        if schema.HasField("null"):
-            return NullType()
-        elif schema.HasField("boolean"):
-            return BooleanType()
-        elif schema.HasField("binary"):
-            return BinaryType()
-        elif schema.HasField("byte"):
-            return ByteType()
-        elif schema.HasField("short"):
-            return ShortType()
-        elif schema.HasField("integer"):
-            return IntegerType()
-        elif schema.HasField("long"):
-            return LongType()
-        elif schema.HasField("float"):
-            return FloatType()
-        elif schema.HasField("double"):
-            return DoubleType()
-        elif schema.HasField("decimal"):
-            p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
-            s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
-            return DecimalType(precision=p, scale=s)
-        elif schema.HasField("string"):
-            return StringType()
-        elif schema.HasField("char"):
-            return CharType(schema.char.length)
-        elif schema.HasField("var_char"):
-            return VarcharType(schema.var_char.length)
-        elif schema.HasField("date"):
-            return DateType()
-        elif schema.HasField("timestamp"):
-            return TimestampType()
-        elif schema.HasField("day_time_interval"):
-            start: Optional[int] = (
-                schema.day_time_interval.start_field
-                if schema.day_time_interval.HasField("start_field")
-                else None
-            )
-            end: Optional[int] = (
-                schema.day_time_interval.end_field
-                if schema.day_time_interval.HasField("end_field")
-                else None
-            )
-            return DayTimeIntervalType(startField=start, endField=end)
-        elif schema.HasField("array"):
-            return ArrayType(
-                self._proto_schema_to_pyspark_schema(schema.array.element_type),
-                schema.array.contains_null,
-            )
-        elif schema.HasField("struct"):
-            fields = [
-                StructField(
-                    f.name,
-                    self._proto_schema_to_pyspark_schema(f.data_type),
-                    f.nullable,
-                )
-                for f in schema.struct.fields
-            ]
-            return StructType(fields)
-        elif schema.HasField("map"):
-            return MapType(
-                self._proto_schema_to_pyspark_schema(schema.map.key_type),
-                self._proto_schema_to_pyspark_schema(schema.map.value_type),
-                schema.map.value_contains_null,
-            )
-        else:
-            raise Exception(f"Unsupported data type {schema}")
+        return types.proto_schema_to_pyspark_data_type(schema)
 
     def schema(self, plan: pb2.Plan) -> StructType:
         proto_schema = self._analyze(plan).schema
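With the conversion logic hoisted out of the client, `_proto_schema_to_pyspark_schema` becomes a one-line delegation and the helper can be exercised on its own. A hedged round-trip sketch (standalone invocation is illustrative; inside the client it runs via `_proto_schema_to_pyspark_schema`):

    import pyspark.sql.connect.proto as pb2
    from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
    from pyspark.sql.types import IntegerType

    # Build a proto DataType by hand and convert it back to a PySpark type.
    proto_type = pb2.DataType(integer=pb2.DataType.Integer())
    assert isinstance(proto_schema_to_pyspark_data_type(proto_type), IntegerType)
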
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index e864f6c93e3b..63e95c851db8 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -21,9 +21,10 @@
 import decimal
 import datetime
 
-from pyspark.sql.types import TimestampType, DayTimeIntervalType, DateType
+from pyspark.sql.types import TimestampType, DayTimeIntervalType, DataType, DateType
 import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
@@ -355,6 +356,29 @@ def __repr__(self) -> str:
         return f"{self._name}({', '.join([str(arg) for arg in self._args])})"
 
 
+class CastExpression(Expression):
+    def __init__(
+        self,
+        col: "Column",
+        data_type: Union[DataType, str],
+    ) -> None:
+        super().__init__()
+        self._col = col
+        self._data_type = data_type
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        fun = proto.Expression()
+        fun.cast.expr.CopyFrom(self._col.to_plan(session))
+        if isinstance(self._data_type, str):
+            fun.cast.type_str = self._data_type
+        else:
+            fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))
+        return fun
+
+    def __repr__(self) -> str:
+        return f"({self._col} ({self._data_type}))"
+
+
 class Column:
     """
     A column in a DataFrame. Column can refer to different things based on the
@@ -733,6 +757,28 @@ def desc_nulls_last(self) -> "Column":
     def name(self) -> str:
         return self._expr.name()
 
+    def cast(self, dataType: Union[DataType, str]) -> "Column":
+        """
+        Casts the column into type ``dataType``.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        dataType : :class:`DataType` or str
+            a DataType or a Python string literal with a DDL-formatted string
+            to use when casting the column to the given type.
+
+        Returns
+        -------
+        :class:`Column`
+            Column with each element cast to the new type.
+        """
+        if isinstance(dataType, (DataType, str)):
+            return Column(CastExpression(col=self, data_type=dataType))
+        else:
+            raise TypeError("unexpected type: %s" % type(dataType))
+
     # TODO(SPARK-41329): solve the circular import between functions.py and
     # this class if we want to reuse functions.lit
     def _lit(self, x: Any) -> "Column":
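Taken together, `Column.cast` accepts either a `DataType` or a DDL string, and `CastExpression.to_plan` picks the matching oneof branch. Illustrative usage, assuming `df` is a Spark Connect DataFrame with an integer column `id`:

    from pyspark.sql.types import LongType

    df.select(df.id.cast("string"))    # serialized as Cast.type_str = "string"
    df.select(df.id.cast(LongType()))  # serialized as Cast.type (a DataType message)
    # df.id.cast(42) would raise TypeError("unexpected type: <class 'int'>")
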
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 8510216324d3..91c57a9ef220 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xd2\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1ap\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x39\n\x0c\x63\x61st_to_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\ncastToType\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xf4\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )
 
@@ -197,31 +197,31 @@
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 2720
-    _EXPRESSION_CAST._serialized_start = 639
-    _EXPRESSION_CAST._serialized_end = 751
-    _EXPRESSION_LITERAL._serialized_start = 754
-    _EXPRESSION_LITERAL._serialized_end = 2212
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1650
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1767
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1769
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1867
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 1869
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 1936
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 1938
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 2004
-    _EXPRESSION_LITERAL_MAP._serialized_start = 2007
-    _EXPRESSION_LITERAL_MAP._serialized_end = 2196
-    _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2080
-    _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2196
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2214
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2284
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2287
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2491
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2493
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2543
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2545
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2585
-    _EXPRESSION_ALIAS._serialized_start = 2587
-    _EXPRESSION_ALIAS._serialized_end = 2707
+    _EXPRESSION._serialized_end = 2754
+    _EXPRESSION_CAST._serialized_start = 640
+    _EXPRESSION_CAST._serialized_end = 785
+    _EXPRESSION_LITERAL._serialized_start = 788
+    _EXPRESSION_LITERAL._serialized_end = 2246
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1684
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1801
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1803
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1901
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 1903
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 1970
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 1972
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 2038
+    _EXPRESSION_LITERAL_MAP._serialized_start = 2041
+    _EXPRESSION_LITERAL_MAP._serialized_end = 2230
+    _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2114
+    _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2230
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2248
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2318
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2321
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2525
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2527
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2577
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2579
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2619
+    _EXPRESSION_ALIAS._serialized_start = 2621
+    _EXPRESSION_ALIAS._serialized_end = 2741
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index c1034a863601..2c486f62a9dd 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -60,27 +60,51 @@ class Expression(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         EXPR_FIELD_NUMBER: builtins.int
-        CAST_TO_TYPE_FIELD_NUMBER: builtins.int
+        TYPE_FIELD_NUMBER: builtins.int
+        TYPE_STR_FIELD_NUMBER: builtins.int
         @property
         def expr(self) -> global___Expression:
             """(Required) the expression to be casted."""
         @property
-        def cast_to_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
-            """(Required) the data type that the expr to be casted to."""
+        def type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        type_str: builtins.str
+        """If set, the server will use the Catalyst parser to parse this string into a DataType."""
         def __init__(
             self,
             *,
             expr: global___Expression | None = ...,
-            cast_to_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+            type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+            type_str: builtins.str = ...,
         ) -> None: ...
         def HasField(
             self,
-            field_name: typing_extensions.Literal["cast_to_type", b"cast_to_type", "expr", b"expr"],
+            field_name: typing_extensions.Literal[
+                "cast_to_type",
+                b"cast_to_type",
+                "expr",
+                b"expr",
+                "type",
+                b"type",
+                "type_str",
+                b"type_str",
+            ],
        ) -> builtins.bool: ...
         def ClearField(
             self,
-            field_name: typing_extensions.Literal["cast_to_type", b"cast_to_type", "expr", b"expr"],
+            field_name: typing_extensions.Literal[
+                "cast_to_type",
+                b"cast_to_type",
+                "expr",
+                b"expr",
+                "type",
+                b"type",
+                "type_str",
+                b"type_str",
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["cast_to_type", b"cast_to_type"]
+        ) -> typing_extensions.Literal["type", "type_str"] | None: ...
 
     class Literal(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
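The regenerated stub exposes `WhichOneof` for the new group, the client-side counterpart of the `getCastToTypeCase` match in `SparkConnectPlanner`. A small sketch (the `describe_cast` helper is hypothetical):

    import pyspark.sql.connect.proto as pb2

    def describe_cast(cast: pb2.Expression.Cast) -> str:
        # Mirrors the server-side dispatch: a structured type is converted
        # directly; a string is handed to the Catalyst parser.
        if cast.WhichOneof("cast_to_type") == "type":
            return f"structured type: {cast.type}"
        return f"DDL string for the server to parse: {cast.type_str!r}"
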
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
new file mode 100644
index 000000000000..55f595366079
--- /dev/null
+++ b/python/pyspark/sql/connect/types.py
@@ -0,0 +1,143 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.types import (
+    DataType,
+    ByteType,
+    ShortType,
+    IntegerType,
+    FloatType,
+    DateType,
+    TimestampType,
+    DayTimeIntervalType,
+    MapType,
+    StringType,
+    CharType,
+    VarcharType,
+    StructType,
+    StructField,
+    ArrayType,
+    DoubleType,
+    LongType,
+    DecimalType,
+    BinaryType,
+    BooleanType,
+    NullType,
+)
+
+
+def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
+    ret = pb2.DataType()
+    if isinstance(data_type, StringType):
+        ret.string.CopyFrom(pb2.DataType.String())
+    elif isinstance(data_type, BooleanType):
+        ret.boolean.CopyFrom(pb2.DataType.Boolean())
+    elif isinstance(data_type, BinaryType):
+        ret.binary.CopyFrom(pb2.DataType.Binary())
+    elif isinstance(data_type, ByteType):
+        ret.byte.CopyFrom(pb2.DataType.Byte())
+    elif isinstance(data_type, ShortType):
+        ret.short.CopyFrom(pb2.DataType.Short())
+    elif isinstance(data_type, IntegerType):
+        ret.integer.CopyFrom(pb2.DataType.Integer())
+    elif isinstance(data_type, LongType):
+        ret.long.CopyFrom(pb2.DataType.Long())
+    elif isinstance(data_type, FloatType):
+        ret.float.CopyFrom(pb2.DataType.Float())
+    elif isinstance(data_type, DoubleType):
+        ret.double.CopyFrom(pb2.DataType.Double())
+    elif isinstance(data_type, DecimalType):
+        ret.decimal.CopyFrom(pb2.DataType.Decimal())
+    elif isinstance(data_type, DayTimeIntervalType):
+        ret.day_time_interval.start_field = data_type.startField
+        ret.day_time_interval.end_field = data_type.endField
+    else:
+        raise Exception(f"Unsupported data type {data_type}")
+    return ret
+
+
+def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
+    if schema.HasField("null"):
+        return NullType()
+    elif schema.HasField("boolean"):
+        return BooleanType()
+    elif schema.HasField("binary"):
+        return BinaryType()
+    elif schema.HasField("byte"):
+        return ByteType()
+    elif schema.HasField("short"):
+        return ShortType()
+    elif schema.HasField("integer"):
+        return IntegerType()
+    elif schema.HasField("long"):
+        return LongType()
+    elif schema.HasField("float"):
+        return FloatType()
+    elif schema.HasField("double"):
+        return DoubleType()
+    elif schema.HasField("decimal"):
+        p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
+        s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
+        return DecimalType(precision=p, scale=s)
+    elif schema.HasField("string"):
+        return StringType()
+    elif schema.HasField("char"):
+        return CharType(schema.char.length)
+    elif schema.HasField("var_char"):
+        return VarcharType(schema.var_char.length)
+    elif schema.HasField("date"):
+        return DateType()
+    elif schema.HasField("timestamp"):
+        return TimestampType()
+    elif schema.HasField("day_time_interval"):
+        start: Optional[int] = (
+            schema.day_time_interval.start_field
+            if schema.day_time_interval.HasField("start_field")
+            else None
+        )
+        end: Optional[int] = (
+            schema.day_time_interval.end_field
+            if schema.day_time_interval.HasField("end_field")
+            else None
+        )
+        return DayTimeIntervalType(startField=start, endField=end)
+    elif schema.HasField("array"):
+        return ArrayType(
+            proto_schema_to_pyspark_data_type(schema.array.element_type),
+            schema.array.contains_null,
+        )
+    elif schema.HasField("struct"):
+        fields = [
+            StructField(
+                f.name,
+                proto_schema_to_pyspark_data_type(f.data_type),
+                f.nullable,
+            )
+            for f in schema.struct.fields
+        ]
+        return StructType(fields)
+    elif schema.HasField("map"):
+        return MapType(
+            proto_schema_to_pyspark_data_type(schema.map.key_type),
+            proto_schema_to_pyspark_data_type(schema.map.value_type),
+            schema.map.value_contains_null,
+        )
+    else:
+        raise Exception(f"Unsupported data type {schema}")
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 106ab609bfa3..c73f1b5b0c75 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -16,7 +16,20 @@
 #
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
 from pyspark.testing.sqlutils import have_pandas
+from pyspark.sql.types import (
+    ByteType,
+    ShortType,
+    IntegerType,
+    FloatType,
+    DayTimeIntervalType,
+    StringType,
+    DoubleType,
+    LongType,
+    DecimalType,
+    BinaryType,
+    BooleanType,
+)
 
 if have_pandas:
     from pyspark.sql.connect.functions import lit
 
@@ -80,6 +93,31 @@ def test_simple_binary_expressions(self):
         res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
         self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
 
+    def test_cast(self):
+        df = self.connect.read.table(self.tbl_name)
+        df2 = self.spark.read.table(self.tbl_name)
+
+        self.assert_eq(
+            df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas()
+        )
+
+        for x in [
+            StringType(),
+            BinaryType(),
+            ShortType(),
+            IntegerType(),
+            LongType(),
+            FloatType(),
+            DoubleType(),
+            ByteType(),
+            DecimalType(10, 2),
+            BooleanType(),
+            DayTimeIntervalType(),
+        ]:
+            self.assert_eq(
+                df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
+            )
+
 
 if __name__ == "__main__":
     import unittest
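One limitation worth noting in the new client-to-proto direction: `pyspark_types_to_proto_types` covers only atomic types, and its `DecimalType` branch copies a default `Decimal` message, so precision and scale are not propagated. An illustrative probe:

    from pyspark.sql.connect.types import pyspark_types_to_proto_types
    from pyspark.sql.types import ArrayType, DecimalType, IntegerType

    pyspark_types_to_proto_types(IntegerType())       # converts
    pyspark_types_to_proto_types(DecimalType(10, 2))  # converts, but precision/scale are dropped
    try:
        pyspark_types_to_proto_types(ArrayType(IntegerType()))
    except Exception as e:
        print(e)  # "Unsupported data type ..." -- nested types are not handled yet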