diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9e3899f4a1a0..c40afdee6524 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -67,7 +67,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
       case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
       case proto.Relation.RelTypeCase.LOCAL_RELATION =>
-        transformLocalRelation(rel.getLocalRelation)
+        transformLocalRelation(rel.getLocalRelation, common)
       case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
@@ -125,9 +125,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
   }
 
-  private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
+  private def transformLocalRelation(
+      rel: proto.LocalRelation,
+      common: Option[proto.RelationCommon]): LogicalPlan = {
     val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
-    new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
+    val relation = new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
+    if (common.nonEmpty && common.get.getAlias.nonEmpty) {
+      logical.SubqueryAlias(identifier = common.get.getAlias, child = relation)
+    } else {
+      relation
+    }
   }
 
   private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala
deleted file mode 100644
index 88af60581ba2..000000000000
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.connect.planner
-
-import org.apache.spark.sql.{Dataset, Row, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-
-/**
- * [[SparkConnectPlanTestWithSparkSession]] contains a SparkSession for the connect planner.
- *
- * It is not recommended to use Catalyst DSL along with this trait because `SharedSparkSession`
- * has also defined implicits over Catalyst LogicalPlan which will cause ambiguity with the
- * implicits defined in Catalyst DSL.
- */
-trait SparkConnectPlanTestWithSparkSession extends SharedSparkSession with SparkConnectPlanTest {
-  override def getSession(): SparkSession = spark
-}
-
-class SparkConnectDeduplicateSuite extends SparkConnectPlanTestWithSparkSession {
-  lazy val connectTestRelation = createLocalRelationProto(
-    Seq(
-      AttributeReference("id", IntegerType)(),
-      AttributeReference("key", StringType)(),
-      AttributeReference("value", StringType)()))
-
-  lazy val sparkTestRelation = {
-    spark.createDataFrame(
-      new java.util.ArrayList[Row](),
-      StructType(
-        Seq(
-          StructField("id", IntegerType),
-          StructField("key", StringType),
-          StructField("value", StringType))))
-  }
-
-  test("Test basic deduplicate") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      Dataset.ofRows(spark, transform(connectTestRelation.distinct()))
-    }
-
-    val sparkPlan = sparkTestRelation.distinct()
-    comparePlans(connectPlan.queryExecution.analyzed, sparkPlan.queryExecution.analyzed, false)
-
-    val connectPlan2 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      Dataset.ofRows(spark, transform(connectTestRelation.deduplicate(Seq("key", "value"))))
-    }
-    val sparkPlan2 = sparkTestRelation.dropDuplicates(Seq("key", "value"))
-    comparePlans(connectPlan2.queryExecution.analyzed, sparkPlan2.queryExecution.analyzed, false)
-  }
-}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 6fc47e07c598..49072982c00c 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -22,20 +22,18 @@ import scala.collection.JavaConverters._
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression.UnresolvedStar
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.test.SharedSparkSession
 
 /**
  * Testing trait for SparkConnect tests with some helper methods to make it easier to create new
  * test cases.
  */
-trait SparkConnectPlanTest {
-
-  def getSession(): SparkSession = None.orNull
+trait SparkConnectPlanTest extends SharedSparkSession {
 
   def transform(rel: proto.Relation): LogicalPlan = {
-    new SparkConnectPlanner(rel, getSession()).transform()
+    new SparkConnectPlanner(rel, spark).transform()
   }
 
   def readRel: proto.Relation =
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 0325b6573bd3..a38b1951eb23 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -18,10 +18,16 @@ package org.apache.spark.sql.connect.planner
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter, UsingJoin}
+import org.apache.spark.sql.{Column, DataFrame, Row}
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connect.dsl.expressions._
+import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 /**
  * This suite is based on connect DSL and test that given same dataframe operations, whether
@@ -30,81 +36,61 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  */
 class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
-  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))
+  lazy val connectTestRelation =
+    createLocalRelationProto(
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))
 
-  lazy val connectTestRelation2 = createLocalRelationProto(
-    Seq($"key".int, $"value".int, $"name".string))
+  lazy val connectTestRelation2 =
+    createLocalRelationProto(
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))
 
-  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)
+  lazy val sparkTestRelation: DataFrame =
+    spark.createDataFrame(
+      new java.util.ArrayList[Row](),
+      StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
 
-  lazy val sparkTestRelation2: LocalRelation =
-    LocalRelation($"key".int, $"value".int, $"name".string)
+  lazy val sparkTestRelation2: DataFrame =
+    spark.createDataFrame(
+      new java.util.ArrayList[Row](),
+      StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
 
   test("Basic select") {
-    val connectPlan = {
-      // TODO: Scala only allows one implicit per scope so we keep proto implicit imports in
-      // this scope. Need to find a better way to make two implicits work in the same scope.
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.select("id".protoAttr))
-    }
-    val sparkPlan = sparkTestRelation.select($"id")
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.select("id".protoAttr)
+    val sparkPlan = sparkTestRelation.select("id")
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("UnresolvedFunction resolution.") {
-    {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      assertThrows[IllegalArgumentException] {
-        transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr))))
-      }
+    assertThrows[IllegalArgumentException] {
+      transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr))))
     }
 
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(
-        connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr))))
-    }
+    val connectPlan =
+      connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr)))
 
     assertThrows[UnsupportedOperationException] {
-      connectPlan.analyze
+      analyzePlan(transform(connectPlan))
     }
 
-    val validPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr))))
-    }
-    assert(validPlan.analyze != null)
+    val validPlan = connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr)))
+    assert(analyzePlan(transform(validPlan)) != null)
   }
 
   test("Basic filter") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.where("id".protoAttr < 0))
-    }
-
-    val sparkPlan = sparkTestRelation.where($"id" < 0).analyze
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.where("id".protoAttr < 0)
+    val sparkPlan = sparkTestRelation.where(Column("id") < 0)
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("Basic joins with different join types") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.join(connectTestRelation2))
-    }
+    val connectPlan = connectTestRelation.join(connectTestRelation2)
     val sparkPlan = sparkTestRelation.join(sparkTestRelation2)
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    comparePlans(connectPlan, sparkPlan)
+
+    val connectPlan2 = connectTestRelation.join(connectTestRelation2)
+    val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2)
+    comparePlans(connectPlan2, sparkPlan2)
 
-    val connectPlan2 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.join(connectTestRelation2, condition = None))
-    }
-    val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None)
-    comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
     for ((t, y) <- Seq(
         (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
         (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
@@ -112,99 +98,79 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
         (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
         (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
         (JoinType.JOIN_TYPE_INNER, Inner))) {
-      val connectPlan3 = {
-        import org.apache.spark.sql.connect.dsl.plans._
-        transform(connectTestRelation.join(connectTestRelation2, t))
-      }
-      val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y)
-      comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
-    }
-    val connectPlan4 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(
-        connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, Seq("name")))
+      val connectPlan3 = connectTestRelation.join(connectTestRelation2, t, Seq("id"))
+      val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, Seq("id"), y.toString)
+      comparePlans(connectPlan3, sparkPlan3)
     }
-    val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, UsingJoin(Inner, Seq("name")))
-    comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false)
+
+    val connectPlan4 =
+      connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, Seq("name"))
+    val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, Seq("name"), Inner.toString)
+    comparePlans(connectPlan4, sparkPlan4)
   }
 
   test("Test sample") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.sample(0, 0.2, false, 1))
-    }
-    val sparkPlan = sparkTestRelation.sample(0, 0.2, false, 1)
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.sample(0, 0.2, false, 1)
+    val sparkPlan = sparkTestRelation.sample(false, 0.2 - 0, 1)
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("column alias") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.select("id".protoAttr.as("id2")))
-    }
-    val sparkPlan = sparkTestRelation.select($"id".as("id2"))
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
+    val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("Aggregate with more than 1 grouping expressions") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)())
+    withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
+      val connectPlan =
+        connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)()
+      val sparkPlan =
+        sparkTestRelation.groupBy(Column("id"), Column("name")).agg(Map.empty[String, String])
+      comparePlans(connectPlan, sparkPlan)
     }
-    val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
 
   test("Test as(alias: String)") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.as("target_table"))
-    }
-
+    val connectPlan = connectTestRelation.as("target_table")
     val sparkPlan = sparkTestRelation.as("target_table")
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("Test StructType in LocalRelation") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      transform(createLocalRelationProtoByQualifiedAttributes(Seq("a".struct("id".int))))
-    }
-    val sparkPlan = LocalRelation($"a".struct($"id".int))
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = createLocalRelationProtoByQualifiedAttributes(Seq("a".struct("id".int)))
+    val sparkPlan =
+      LocalRelation(AttributeReference("a", StructType(Seq(StructField("id", IntegerType))))())
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("Test limit offset") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.limit(10))
-    }
+    val connectPlan = connectTestRelation.limit(10)
     val sparkPlan = sparkTestRelation.limit(10)
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    comparePlans(connectPlan, sparkPlan)
 
-    val connectPlan2 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.offset(2))
-    }
+    val connectPlan2 = connectTestRelation.offset(2)
     val sparkPlan2 = sparkTestRelation.offset(2)
-    comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
+    comparePlans(connectPlan2, sparkPlan2)
 
-    val connectPlan3 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.limit(10).offset(2))
-    }
+    val connectPlan3 = connectTestRelation.limit(10).offset(2)
     val sparkPlan3 = sparkTestRelation.limit(10).offset(2)
-    comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
+    comparePlans(connectPlan3, sparkPlan3)
 
-    val connectPlan4 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.offset(2).limit(10))
-    }
+    val connectPlan4 = connectTestRelation.offset(2).limit(10)
     val sparkPlan4 = sparkTestRelation.offset(2).limit(10)
-    comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false)
+    comparePlans(connectPlan4, sparkPlan4)
+  }
+
+  test("Test basic deduplicate") {
+    val connectPlan = connectTestRelation.distinct()
+    val sparkPlan = sparkTestRelation.distinct()
+    comparePlans(connectPlan, sparkPlan)
+
+    val connectPlan2 = connectTestRelation.deduplicate(Seq("id", "name"))
+    val sparkPlan2 = sparkTestRelation.dropDuplicates(Seq("id", "name"))
+    comparePlans(connectPlan2, sparkPlan2)
   }
 
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
@@ -215,4 +181,17 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
     proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
   }
+
+  // This is a helper for testing only. It is used when the plan is already built and only
+  // needs the analyzer to resolve the attribute references within the plan.
+  private def analyzePlan(plan: LogicalPlan): LogicalPlan = {
+    val connectAnalyzed = analysis.SimpleAnalyzer.execute(plan)
+    analysis.SimpleAnalyzer.checkAnalysis(connectAnalyzed)
+    connectAnalyzed
+  }
+
+  private def comparePlans(connectPlan: proto.Relation, sparkPlan: DataFrame): Unit = {
+    val connectAnalyzed = analyzePlan(transform(connectPlan))
+    comparePlans(connectAnalyzed, sparkPlan.queryExecution.analyzed, false)
+  }
 }
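
Note: the planner change above has a small, self-contained core: build the `LocalRelation` first, then wrap it in a `SubqueryAlias` only when the client sent a non-empty alias. Below is a minimal sketch of that pattern outside the planner, using a hypothetical `RelationCommonStub` case class in place of the real `proto.RelationCommon` (the actual planner reads the alias via `common.get.getAlias`):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.types.IntegerType

object AliasWrappingSketch {
  // Hypothetical stand-in for proto.RelationCommon; only the alias field matters here.
  final case class RelationCommonStub(alias: String)

  // Mirrors transformLocalRelation above: build the relation first, then wrap it
  // in a SubqueryAlias only when a non-empty alias was supplied.
  def wrapWithAlias(relation: LogicalPlan, common: Option[RelationCommonStub]): LogicalPlan =
    common match {
      case Some(c) if c.alias.nonEmpty => SubqueryAlias(identifier = c.alias, child = relation)
      case _ => relation
    }

  def main(args: Array[String]): Unit = {
    val relation = LocalRelation(AttributeReference("id", IntegerType)())
    // Prints a SubqueryAlias node wrapping the LocalRelation.
    println(wrapWithAlias(relation, Some(RelationCommonStub("target_table"))))
    // Prints the bare LocalRelation, unchanged.
    println(wrapWithAlias(relation, None))
  }
}
```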
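The wrapping is what lets the `Test as(alias: String)` case compare equal against `DataFrame.as`: a `SubqueryAlias` gives the relation a name that qualified column references can resolve against. A small usage sketch of that behavior on the DataFrame side, assuming a live `SparkSession` named `spark` as provided by `SharedSparkSession`:

```scala
import org.apache.spark.sql.functions.col

// Alias a one-column DataFrame, then resolve the column both unqualified
// and qualified through the alias; both select the same column.
val df = spark.range(1).toDF("id").as("target_table")
df.select(df("id")).show()
df.select(col("target_table.id")).show()
```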