diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 791b1b5887b7..4b5a81d2a568 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -35,6 +35,7 @@ message Expression {
     UnresolvedFunction unresolved_function = 3;
     ExpressionString expression_string = 4;
     UnresolvedStar unresolved_star = 5;
+    Alias alias = 6;
   }
 
   message Literal {
@@ -166,4 +167,9 @@ message Expression {
     string name = 1;
     DataType type = 2;
   }
+
+  message Alias {
+    Expression expr = 1;
+    string name = 2;
+  }
 }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 3ccf71c26b74..80d6e77c9fc4 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -40,6 +40,11 @@ package object dsl {
             .build())
           .build()
     }
+
+    implicit class DslExpression(val expr: proto.Expression) {
+      def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias(
+        proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build()
+    }
   }
 
   object plans { // scalastyle:ignore
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 66560f5e62f6..5ad95a6b516a 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -24,7 +24,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.types._
@@ -132,6 +132,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
         transformUnresolvedExpression(exp)
       case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
         transformScalarFunction(exp.getUnresolvedFunction)
+      case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
       case _ => throw InvalidPlanInput()
     }
   }
@@ -208,6 +209,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
   }
 
+  private def transformAlias(alias: proto.Expression.Alias): Expression = {
+    Alias(transformExpression(alias.getExpr), alias.getName)()
+  }
+
   private def transformUnion(u: proto.Union): LogicalPlan = {
     assert(u.getInputsCount == 2, "Union must have 2 inputs")
     val plan = logical.Union(transformRelation(u.getInputs(0)), transformRelation(u.getInputs(1)))
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 441a3a9f1e41..510b54cd2508 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -81,6 +81,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
   }
 
+  test("column alias") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.select("id".protoAttr.as("id2")))
+    }
+    val sparkPlan = sparkTestRelation.select($"id".as("id2"))
+    comparePlans(connectPlan, sparkPlan, checkAnalysis = false)
+  }
+
   test("Aggregate with more than 1 grouping expressions") {
     val connectPlan = {
       import org.apache.spark.sql.connect.dsl.expressions._