diff --git a/connector/connect/README.md b/connector/connect/README.md
index faad529c515c..8e53d0d6c692 100644
--- a/connector/connect/README.md
+++ b/connector/connect/README.md
@@ -70,3 +70,4 @@ When contributing a new client please be aware that we strive to have a common
 user experience across all languages. Please follow the below guidelines:
 
 * [Connection string configuration](docs/client-connection-string.md)
+* [Adding new messages](docs/adding-proto-messages.md) in the Spark Connect protocol.
diff --git a/connector/connect/docs/adding-proto-messages.md b/connector/connect/docs/adding-proto-messages.md
new file mode 100644
index 000000000000..85e7bb79e0a3
--- /dev/null
+++ b/connector/connect/docs/adding-proto-messages.md
@@ -0,0 +1,40 @@
+# Required, Optional and default values
+
+Spark Connect adopts proto3, which no longer supports the `required` constraint.
+For non-message proto fields, there is also no `has_field_name` function to easily tell
+whether a field is set or not. (Read [proto3 field rules](https://developers.google.com/protocol-buffers/docs/proto3#specifying_field_rules))
+
+
+### Required fields
+
+When adding fields that have required semantics, developers must follow
+the process outlined here. Fields that are semantically required for the server to
+correctly process the incoming message must be documented with `(Required)`. For scalar
+fields the server will not perform any additional input validation. For compound fields,
+the server will perform minimal checks to avoid null pointer exceptions but will not
+perform any semantic validation.
+
+Example:
+```protobuf
+message DataSource {
+  // (Required) Supported formats include: parquet, orc, text, json, csv, avro.
+  string format = 1;
+}
+```
+
+
+### Optional fields
+
+Semantically optional fields must be marked as `optional`. The server side will
+then use this information to branch into different behaviors based on the presence or absence of this field.
+
+Due to the lack of configurable default values for scalar types, the mere presence of
+an optional field does not define its default value. The server-side implementation will interpret the observed value based on its own rules.
+
+Example:
+```protobuf
+message DataSource {
+  // (Optional) If not set, Spark will infer the schema.
+  optional string schema = 2;
+}
+```
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 426e04341d9f..ac5fe24d349e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -170,6 +170,8 @@ message Expression {
   message Alias {
     Expression expr = 1;
-    string name = 2;
+    repeated string name = 2;
+    // Alias metadata expressed as a JSON map.
+ optional string metadata = 3; } } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 98879b69b848..caff3d8f0713 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -79,7 +79,28 @@ package object dsl { implicit class DslExpression(val expr: Expression) { def as(alias: String): Expression = Expression .newBuilder() - .setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr)) + .setAlias(Expression.Alias.newBuilder().addName(alias).setExpr(expr)) + .build() + + def as(alias: String, metadata: String): Expression = Expression + .newBuilder() + .setAlias( + Expression.Alias + .newBuilder() + .setExpr(expr) + .addName(alias) + .setMetadata(metadata) + .build()) + .build() + + def as(alias: Seq[String]): Expression = Expression + .newBuilder() + .setAlias( + Expression.Alias + .newBuilder() + .setExpr(expr) + .addAllName(alias.asJava) + .build()) .build() def <(other: Expression): Expression = @@ -101,6 +122,13 @@ package object dsl { Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e)) .build() + def proto_explode(e: Expression): Expression = + Expression + .newBuilder() + .setUnresolvedFunction( + Expression.UnresolvedFunction.newBuilder().addParts("explode").addArguments(e)) + .build() + /** * Create an unresolved function from name parts. * diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala index 0ee90b5e8fbb..088030b2dbcf 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._ import org.apache.spark.connect.proto import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType} /** * This object offers methods to convert to/from connect proto to catalyst types. 
@@ -32,6 +32,7 @@ object DataTypeProtoConverter { case proto.DataType.KindCase.I32 => IntegerType case proto.DataType.KindCase.STRING => StringType case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct) + case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap) case _ => throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.") } @@ -44,6 +45,10 @@ object DataTypeProtoConverter { StructType.apply(structFields) } + private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = { + MapType(toCatalystType(t.getKey), toCatalystType(t.getValue)) + } + def toConnectProtoType(t: DataType): proto.DataType = { t match { case IntegerType => @@ -54,11 +59,24 @@ object DataTypeProtoConverter { proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build() case struct: StructType => toConnectProtoStructType(struct) + case map: MapType => toConnectProtoMapType(map) case _ => throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") } } + def toConnectProtoMapType(schema: MapType): proto.DataType = { + proto.DataType + .newBuilder() + .setMap( + proto.DataType.Map + .newBuilder() + .setKey(toConnectProtoType(schema.keyType)) + .setValue(toConnectProtoType(schema.valueType)) + .build()) + .build() + } + def toConnectProtoStructType(schema: StructType): proto.DataType = { val struct = proto.DataType.Struct.newBuilder() for (structField <- schema.fields) { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f5801d57c29f..232f6e10474d 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -27,7 +27,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.WriteOperation import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.AliasIdentifier -import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.CombineUnions @@ -338,7 +338,9 @@ class SparkConnectPlanner(session: SparkSession) { case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias) case proto.Expression.ExprTypeCase.EXPRESSION_STRING => transformExpressionString(exp.getExpressionString) - case _ => throw InvalidPlanInput() + case _ => + throw InvalidPlanInput( + s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported") } } @@ -412,7 +414,20 @@ class SparkConnectPlanner(session: SparkSession) { } private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { - Alias(transformExpression(alias.getExpr), alias.getName)() + if (alias.getNameCount == 1) { + val md = if (alias.hasMetadata()) { + Some(Metadata.fromJson(alias.getMetadata)) + } else { + None + } + 
Alias(transformExpression(alias.getExpr), alias.getName(0))(explicitMetadata = md) + } else { + if (alias.hasMetadata) { + throw new InvalidPlanInput( + "Alias expressions with more than 1 identifier must not use optional metadata.") + } + MultiAlias(transformExpression(alias.getExpr), alias.getNameList.asScala.toSeq) + } } private def transformExpressionString(expr: proto.Expression.ExpressionString): Expression = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 71633830b562..404581445d0c 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.dsl.commands._ import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, MapType, Metadata, StringType, StructField, StructType} /** * This suite is based on connect DSL and test that given same dataframe operations, whether @@ -50,6 +50,9 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { createLocalRelationProto( Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) + lazy val connectTestRelationMap = + createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))())) + lazy val sparkTestRelation: DataFrame = spark.createDataFrame( new java.util.ArrayList[Row](), @@ -60,6 +63,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { new java.util.ArrayList[Row](), StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))) + lazy val sparkTestRelationMap: DataFrame = + spark.createDataFrame( + new java.util.ArrayList[Row](), + StructType(Seq(StructField("id", MapType(StringType, StringType))))) + lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)())) test("Basic select") { @@ -140,10 +148,35 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connectPlan2, sparkPlan2) } - test("column alias") { + test("SPARK-40809: column alias") { + // Simple Test. val connectPlan = connectTestRelation.select("id".protoAttr.as("id2")) val sparkPlan = sparkTestRelation.select(Column("id").alias("id2")) comparePlans(connectPlan, sparkPlan) + + // Scalar columns with metadata + val mdJson = "{\"max\": 99}" + comparePlans( + connectTestRelation.select("id".protoAttr.as("id2", mdJson)), + sparkTestRelation.select(Column("id").as("id2", Metadata.fromJson(mdJson)))) + + comparePlans( + connectTestRelationMap.select(proto_explode("id".protoAttr).as(Seq("a", "b"))), + sparkTestRelationMap.select(explode(Column("id")).as(Seq("a", "b")))) + + // Metadata must only be specified for regular Aliases. 
+ assertThrows[InvalidPlanInput] { + val attr = proto_explode("id".protoAttr) + val alias = proto.Expression.Alias + .newBuilder() + .setExpr(attr) + .addName("a") + .addName("b") + .setMetadata(mdJson) + .build() + transform( + connectTestRelationMap.select(proto.Expression.newBuilder().setAlias(alias).build())) + } } test("Aggregate with more than 1 grouping expressions") { diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 417bc7097de3..9f610bf18fee 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -17,6 +17,7 @@ import uuid from typing import cast, get_args, TYPE_CHECKING, Callable, Any +import json import decimal import datetime @@ -82,6 +83,71 @@ def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": def __str__(self) -> str: ... + def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias": + """ + Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). + + .. versionadded:: 3.4.0 + + Parameters + ---------- + alias : str + desired column names (collects all positional arguments passed) + + Other Parameters + ---------------- + metadata: dict + a dict of information to be stored in ``metadata`` attribute of the + corresponding :class:`StructField ` (optional, keyword + only argument) + + Returns + ------- + :class:`Column` + Column representing whether each element of Column is aliased with new name or names. + + Examples + -------- + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (5, "Bob")], ["age", "name"]) + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + >>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max'] + 99 + """ + metadata = kwargs.pop("metadata", None) + assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs + return ColumnAlias(self, list(alias), metadata) + + +class ColumnAlias(Expression): + def __init__(self, parent: Expression, alias: list[str], metadata: Any): + + self._alias = alias + self._metadata = metadata + self._parent = parent + + def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + if len(self._alias) == 1: + exp = proto.Expression() + exp.alias.name.append(self._alias[0]) + exp.alias.expr.CopyFrom(self._parent.to_plan(session)) + + if self._metadata: + exp.alias.metadata = json.dumps(self._metadata) + return exp + else: + if self._metadata: + raise ValueError("metadata can only be provided for a single column") + exp = proto.Expression() + exp.alias.name.extend(self._alias) + exp.alias.expr.CopyFrom(self._parent.to_plan(session)) + return exp + + def __str__(self) -> str: + return f"Alias({self._parent}, ({','.join(self._alias)}))" + class LiteralExpression(Expression): """A literal expression. 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index e2334b66c680..5dd28c0e6a94 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -120,7 +120,7 @@ def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "Dat new_frame._plan = plan return new_frame - def select(self, *cols: ColumnOrName) -> "DataFrame": + def select(self, *cols: "ExpressionOrString") -> "DataFrame": return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session) def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame": diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 89718650571f..dca9d2cef47f 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xc2\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x1a\xa3\x10\n\x07Literal\x12\x1a\n\x07\x62oolean\x18\x01 \x01(\x08H\x00R\x07\x62oolean\x12\x10\n\x02i8\x18\x02 \x01(\x05H\x00R\x02i8\x12\x12\n\x03i16\x18\x03 \x01(\x05H\x00R\x03i16\x12\x12\n\x03i32\x18\x05 \x01(\x05H\x00R\x03i32\x12\x12\n\x03i64\x18\x07 \x01(\x03H\x00R\x03i64\x12\x14\n\x04\x66p32\x18\n \x01(\x02H\x00R\x04\x66p32\x12\x14\n\x04\x66p64\x18\x0b \x01(\x01H\x00R\x04\x66p64\x12\x18\n\x06string\x18\x0c \x01(\tH\x00R\x06string\x12\x18\n\x06\x62inary\x18\r \x01(\x0cH\x00R\x06\x62inary\x12\x1e\n\ttimestamp\x18\x0e \x01(\x03H\x00R\ttimestamp\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x14\n\x04time\x18\x11 \x01(\x03H\x00R\x04time\x12l\n\x16interval_year_to_month\x18\x13 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalYearToMonthH\x00R\x13intervalYearToMonth\x12l\n\x16interval_day_to_second\x18\x14 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalDayToSecondH\x00R\x13intervalDayToSecond\x12\x1f\n\nfixed_char\x18\x15 \x01(\tH\x00R\tfixedChar\x12\x46\n\x08var_char\x18\x16 \x01(\x0b\x32).spark.connect.Expression.Literal.VarCharH\x00R\x07varChar\x12#\n\x0c\x66ixed_binary\x18\x17 \x01(\x0cH\x00R\x0b\x66ixedBinary\x12\x45\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x42\n\x06struct\x18\x19 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x1a \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12#\n\x0ctimestamp_tz\x18\x1b \x01(\x03H\x00R\x0btimestampTz\x12\x14\n\x04uuid\x18\x1c \x01(\x0cH\x00R\x04uuid\x12-\n\x04null\x18\x1d \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12<\n\x04list\x18\x1e 
\x01(\x0b\x32&.spark.connect.Expression.Literal.ListH\x00R\x04list\x12=\n\nempty_list\x18\x1f \x01(\x0b\x32\x1c.spark.connect.DataType.ListH\x00R\temptyList\x12:\n\tempty_map\x18 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x08\x65mptyMap\x12R\n\x0cuser_defined\x18! \x01(\x0b\x32-.spark.connect.Expression.Literal.UserDefinedH\x00R\x0buserDefined\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1a\x37\n\x07VarChar\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12\x16\n\x06length\x18\x02 \x01(\rR\x06length\x1aS\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value\x12\x1c\n\tprecision\x18\x02 \x01(\x05R\tprecision\x12\x14\n\x05scale\x18\x03 \x01(\x05R\x05scale\x1a\xce\x01\n\x03Map\x12M\n\nkey_values\x18\x01 \x03(\x0b\x32..spark.connect.Expression.Literal.Map.KeyValueR\tkeyValues\x1ax\n\x08KeyValue\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value\x1a\x43\n\x13IntervalYearToMonth\x12\x14\n\x05years\x18\x01 \x01(\x05R\x05years\x12\x16\n\x06months\x18\x02 \x01(\x05R\x06months\x1ag\n\x13IntervalDayToSecond\x12\x12\n\x04\x64\x61ys\x18\x01 \x01(\x05R\x04\x64\x61ys\x12\x18\n\x07seconds\x18\x02 \x01(\x05R\x07seconds\x12"\n\x0cmicroseconds\x18\x03 \x01(\x05R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x41\n\x04List\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a`\n\x0bUserDefined\x12%\n\x0etype_reference\x18\x01 \x01(\rR\rtypeReference\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\x63\n\x12UnresolvedFunction\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a\x10\n\x0eUnresolvedStar\x1aU\n\x12QualifiedAttribute\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12+\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x04type\x1aJ\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x01(\tR\x04nameB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf0\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x1a\xa3\x10\n\x07Literal\x12\x1a\n\x07\x62oolean\x18\x01 \x01(\x08H\x00R\x07\x62oolean\x12\x10\n\x02i8\x18\x02 \x01(\x05H\x00R\x02i8\x12\x12\n\x03i16\x18\x03 
\x01(\x05H\x00R\x03i16\x12\x12\n\x03i32\x18\x05 \x01(\x05H\x00R\x03i32\x12\x12\n\x03i64\x18\x07 \x01(\x03H\x00R\x03i64\x12\x14\n\x04\x66p32\x18\n \x01(\x02H\x00R\x04\x66p32\x12\x14\n\x04\x66p64\x18\x0b \x01(\x01H\x00R\x04\x66p64\x12\x18\n\x06string\x18\x0c \x01(\tH\x00R\x06string\x12\x18\n\x06\x62inary\x18\r \x01(\x0cH\x00R\x06\x62inary\x12\x1e\n\ttimestamp\x18\x0e \x01(\x03H\x00R\ttimestamp\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x14\n\x04time\x18\x11 \x01(\x03H\x00R\x04time\x12l\n\x16interval_year_to_month\x18\x13 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalYearToMonthH\x00R\x13intervalYearToMonth\x12l\n\x16interval_day_to_second\x18\x14 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalDayToSecondH\x00R\x13intervalDayToSecond\x12\x1f\n\nfixed_char\x18\x15 \x01(\tH\x00R\tfixedChar\x12\x46\n\x08var_char\x18\x16 \x01(\x0b\x32).spark.connect.Expression.Literal.VarCharH\x00R\x07varChar\x12#\n\x0c\x66ixed_binary\x18\x17 \x01(\x0cH\x00R\x0b\x66ixedBinary\x12\x45\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x42\n\x06struct\x18\x19 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x1a \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12#\n\x0ctimestamp_tz\x18\x1b \x01(\x03H\x00R\x0btimestampTz\x12\x14\n\x04uuid\x18\x1c \x01(\x0cH\x00R\x04uuid\x12-\n\x04null\x18\x1d \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12<\n\x04list\x18\x1e \x01(\x0b\x32&.spark.connect.Expression.Literal.ListH\x00R\x04list\x12=\n\nempty_list\x18\x1f \x01(\x0b\x32\x1c.spark.connect.DataType.ListH\x00R\temptyList\x12:\n\tempty_map\x18 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x08\x65mptyMap\x12R\n\x0cuser_defined\x18! 
\x01(\x0b\x32-.spark.connect.Expression.Literal.UserDefinedH\x00R\x0buserDefined\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1a\x37\n\x07VarChar\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12\x16\n\x06length\x18\x02 \x01(\rR\x06length\x1aS\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value\x12\x1c\n\tprecision\x18\x02 \x01(\x05R\tprecision\x12\x14\n\x05scale\x18\x03 \x01(\x05R\x05scale\x1a\xce\x01\n\x03Map\x12M\n\nkey_values\x18\x01 \x03(\x0b\x32..spark.connect.Expression.Literal.Map.KeyValueR\tkeyValues\x1ax\n\x08KeyValue\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value\x1a\x43\n\x13IntervalYearToMonth\x12\x14\n\x05years\x18\x01 \x01(\x05R\x05years\x12\x16\n\x06months\x18\x02 \x01(\x05R\x06months\x1ag\n\x13IntervalDayToSecond\x12\x12\n\x04\x64\x61ys\x18\x01 \x01(\x05R\x04\x64\x61ys\x12\x18\n\x07seconds\x18\x02 \x01(\x05R\x07seconds\x12"\n\x0cmicroseconds\x18\x03 \x01(\x05R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x41\n\x04List\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a`\n\x0bUserDefined\x12%\n\x0etype_reference\x18\x01 \x01(\rR\rtypeReference\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\x63\n\x12UnresolvedFunction\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a\x10\n\x0eUnresolvedStar\x1aU\n\x12QualifiedAttribute\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12+\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x04type\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -43,7 +43,7 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 3115 + _EXPRESSION._serialized_end = 3161 _EXPRESSION_LITERAL._serialized_start = 613 _EXPRESSION_LITERAL._serialized_end = 2696 _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1923 @@ -75,5 +75,5 @@ _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2941 _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3026 _EXPRESSION_ALIAS._serialized_start = 3028 - _EXPRESSION_ALIAS._serialized_end = 3102 + _EXPRESSION_ALIAS._serialized_end = 3148 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index ffaf03eade88..ea538b2ebec7 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -633,21 +633,37 @@ class Expression(google.protobuf.message.Message): EXPR_FIELD_NUMBER: builtins.int NAME_FIELD_NUMBER: builtins.int + METADATA_FIELD_NUMBER: builtins.int @property def 
expr(self) -> global___Expression: ... - name: builtins.str + @property + def name( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + metadata: builtins.str + """Alias metadata expressed as a JSON map.""" def __init__( self, *, expr: global___Expression | None = ..., - name: builtins.str = ..., + name: collections.abc.Iterable[builtins.str] | None = ..., + metadata: builtins.str | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["expr", b"expr"] + self, + field_name: typing_extensions.Literal[ + "_metadata", b"_metadata", "expr", b"expr", "metadata", b"metadata" + ], ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["expr", b"expr", "name", b"name"] + self, + field_name: typing_extensions.Literal[ + "_metadata", b"_metadata", "expr", b"expr", "metadata", b"metadata", "name", b"name" + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"] + ) -> typing_extensions.Literal["metadata"] | None: ... LITERAL_FIELD_NUMBER: builtins.int UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 90a534d599e2..a49829cc0857 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -33,14 +33,15 @@ if have_pandas: from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder from pyspark.sql.connect.function_builder import udf - from pyspark.sql.connect.functions import lit + from pyspark.sql.connect.functions import lit, col from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +from pyspark.testing.pandasutils import PandasOnSparkTestCase from pyspark.testing.utils import ReusedPySparkTestCase @unittest.skipIf(not should_test_connect, connect_requirement_message) -class SparkConnectSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): +class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLTestUtils): """Parent test fixture class for all Spark Connect related test cases.""" @@ -192,20 +193,17 @@ def test_subquery_alias(self) -> None: self.assertTrue("special_alias" in plan_text) def test_range(self): - self.assertTrue( - self.connect.range(start=0, end=10) - .toPandas() - .equals(self.spark.range(start=0, end=10).toPandas()) + self.assert_eq( + self.connect.range(start=0, end=10).toPandas(), + self.spark.range(start=0, end=10).toPandas(), ) - self.assertTrue( - self.connect.range(start=0, end=10, step=3) - .toPandas() - .equals(self.spark.range(start=0, end=10, step=3).toPandas()) + self.assert_eq( + self.connect.range(start=0, end=10, step=3).toPandas(), + self.spark.range(start=0, end=10, step=3).toPandas(), ) - self.assertTrue( - self.connect.range(start=0, end=10, step=3, numPartitions=2) - .toPandas() - .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas()) + self.assert_eq( + self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(), + self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(), ) def test_create_global_temp_view(self): @@ -235,29 +233,21 @@ def test_fill_na(self): # | null| 3| 3.0| # +-----+----+----+ - self.assertTrue( - self.connect.sql(query) - .fillna(True) - .toPandas() - .equals(self.spark.sql(query).fillna(True).toPandas()) + self.assert_eq( + 
self.connect.sql(query).fillna(True).toPandas(), + self.spark.sql(query).fillna(True).toPandas(), ) - self.assertTrue( - self.connect.sql(query) - .fillna(2) - .toPandas() - .equals(self.spark.sql(query).fillna(2).toPandas()) + self.assert_eq( + self.connect.sql(query).fillna(2).toPandas(), + self.spark.sql(query).fillna(2).toPandas(), ) - self.assertTrue( - self.connect.sql(query) - .fillna(2, ["a", "b"]) - .toPandas() - .equals(self.spark.sql(query).fillna(2, ["a", "b"]).toPandas()) + self.assert_eq( + self.connect.sql(query).fillna(2, ["a", "b"]).toPandas(), + self.spark.sql(query).fillna(2, ["a", "b"]).toPandas(), ) - self.assertTrue( - self.connect.sql(query) - .na.fill({"a": True, "b": 2}) - .toPandas() - .equals(self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas()) + self.assert_eq( + self.connect.sql(query).na.fill({"a": True, "b": 2}).toPandas(), + self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(), ) def test_empty_dataset(self): @@ -301,6 +291,20 @@ def test_simple_datasource_read(self) -> None: actualResult = pandasResult.values.tolist() self.assertEqual(len(expectResult), len(actualResult)) + def test_alias(self) -> None: + """Testing supported and unsupported alias""" + col0 = ( + self.connect.range(1, 10) + .select(col("id").alias("name", metadata={"max": 99})) + .schema() + .names[0] + ) + self.assertEqual("name", col0) + + with self.assertRaises(grpc.RpcError) as exc: + self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect() + self.assertIn("Buffer(this, is, not)", str(exc.exception)) + class ChannelBuilderTests(ReusedPySparkTestCase): def test_invalid_connection_strings(self): diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 59e3c97679e8..99b63482a243 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -134,6 +134,16 @@ def test_list_to_literal(self): lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None) self.assertIsNotNone(lit_list_plan) + def test_column_alias(self) -> None: + # SPARK-40809: Support for Column Aliases + col0 = fun.col("a").alias("martin") + self.assertEqual("Alias(Column(a), (martin))", str(col0)) + + col0 = fun.col("a").alias("martin", metadata={"pii": True}) + plan = col0.to_plan(self.session) + self.assertIsNotNone(plan) + self.assertEqual(plan.alias.metadata, '{"pii": true}') + def test_column_expressions(self): """Test a more complex combination of expressions and their translation into the protobuf structure.""" diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index b7c49a6df545..f98a67b9964b 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -18,6 +18,7 @@ from typing import Any, Dict, Optional import functools import unittest + from pyspark.testing.sqlutils import have_pandas if have_pandas: @@ -25,6 +26,7 @@ from pyspark.sql.connect.plan import Read, Range, SQL from pyspark.testing.utils import search_jar from pyspark.sql.connect.plan import LogicalPlan + from pyspark.sql.connect.client import RemoteSparkSession connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") else: @@ -67,6 +69,7 @@ def __getattr__(self, item: str) -> Any: class PlanOnlyTestFixture(unittest.TestCase): connect: "MockRemoteSession" + session: RemoteSparkSession @classmethod def 
_read_table(cls, table_name: str) -> "DataFrame": @@ -99,6 +102,7 @@ def _with_plan(cls, plan: LogicalPlan) -> "DataFrame": @classmethod def setUpClass(cls: Any) -> None: cls.connect = MockRemoteSession() + cls.session = RemoteSparkSession() cls.tbl_name = "test_connect_plan_only_table_1" cls.connect.set_hook("register_udf", cls._udf_mock)
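
For reference, the following is a minimal, illustrative sketch (not part of the diff above) of the client-side behavior this change enables. It mirrors the Python tests in `test_connect_basic.py`; `spark` is assumed here to be a Spark Connect entry point equivalent to the `self.connect` fixture in those tests.

```python
# Illustrative sketch only; `spark` is assumed to be a pyspark.sql.connect
# session-like object (the `self.connect` fixture in the tests above).
from pyspark.sql.connect.functions import col

df = spark.range(1, 10)

# A single-name alias is encoded as an Alias message with one `name` entry;
# the metadata dict is serialized into the new optional JSON `metadata` field.
df.select(col("id").alias("id2", metadata={"max": 99})).schema()

# A multi-name alias is translated to a MultiAlias on the server. Applying it
# to a plain column fails during analysis and surfaces as a grpc.RpcError.
df.select(col("id").alias("this", "is", "not")).collect()
```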