@@ -43,6 +43,7 @@ message Relation {
LocalRelation local_relation = 11;
Sample sample = 12;
Offset offset = 13;
Deduplicate deduplicate = 14;

Unknown unknown = 999;
}
@@ -181,6 +182,14 @@ message Sort {
}
}

// Relation of type [[Deduplicate]] which has duplicate rows removed. It can consider
// either only a subset of the columns or all of the columns.
message Deduplicate {
Relation input = 1;
repeated string column_names = 2;
bool all_columns_as_keys = 3;
}

message LocalRelation {
repeated Expression.QualifiedAttribute attributes = 1;
// TODO: support local data.
@@ -215,6 +215,26 @@ package object dsl {
.build()
}

def deduplicate(colNames: Seq[String]): proto.Relation =
proto.Relation
.newBuilder()
.setDeduplicate(
proto.Deduplicate
.newBuilder()
.setInput(logicalPlan)
.addAllColumnNames(colNames.asJava))
.build()

def distinct(): proto.Relation =
proto.Relation
.newBuilder()
.setDeduplicate(
proto.Deduplicate
.newBuilder()
.setInput(logicalPlan)
.setAllColumnsAsKeys(true))
.build()
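
As a usage note: these builders are exercised through the implicit conversion defined in this dsl package, which attaches them to a proto.Relation. A minimal sketch, assuming a proto.Relation named rel built elsewhere (for example via the readRel helper used in the test suites below):

import org.apache.spark.sql.connect.dsl.plans._

// Deduplicate on an explicit subset of columns.
val bySubset: proto.Relation = rel.deduplicate(Seq("key", "value"))

// Deduplicate on all columns, i.e. SQL DISTINCT.
val byAll: proto.Relation = rel.distinct()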

def join(
otherPlan: proto.Relation,
joinType: JoinType = JoinType.JOIN_TYPE_INNER,
@@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias}
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sample, SubqueryAlias}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._

final case class InvalidPlanInput(
@@ -60,6 +61,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
case proto.Relation.RelTypeCase.UNION => transformUnion(rel.getUnion)
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
@@ -91,6 +93,37 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
transformRelation(rel.getInput))
}

private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
}
if (rel.getAllColumnsAsKeys && rel.getColumnNamesCount > 0) {
throw InvalidPlanInput("Cannot deduplicate on both all columns and a subset of columns")
}
if (!rel.getAllColumnsAsKeys && rel.getColumnNamesCount == 0) {
Member: Why do we need getAllColumnsAsKeys? Seems like we can just tell when the columns are not set.

Member: And this does not match the logical plan.

Contributor: The issue is that in the Spark Connect client we only see column names, not expr IDs. If the DataFrame has duplicated column names, then deduplicating by all columns can't work in the Spark Connect client.

Contributor: Actually, the column names are unknown as well, since the input plan is unresolved.

Member: I mean, we don't need the rel.getAllColumnsAsKeys condition because we can know that's the case when rel.getColumnNamesCount == 0.

Contributor (@amaliujia, Oct 24, 2022): I want to clarify this specifically:

This is one of the Connect proto API design principles: we need to differentiate whether a field is set or not set explicitly. Put another way, every intention should be expressed explicitly. Ultimately, this is to avoid ambiguity on the API surface.

One example is Project. If we see a Project without anything in the project list, how do we interpret that? Does the user want to indicate a SELECT *? Or did the user actually generate an invalid plan? The problem is that there are two possible readings of one plan, and worse, one reading is a valid plan and the other is not. This led us to explicitly encode SELECT * into the proto (#38023).

So one of the reasons we have a bool flag here is to avoid using rel.getColumnNamesCount == 0 to infer distinct on all columns, which has caused ambiguity problems.

This might not be great, because adding a few more fields can bring another problem: what if the user sets them all? In terms of ambiguity, though, this is not an issue: we know that is an invalid plan, with no second reading.
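
To make the ambiguity argument above concrete, here is a sketch of the three shapes the planner distinguishes (hypothetical client-side builder calls; input stands for some proto.Relation):

// Explicit intent: distinct over all columns.
val distinctAll = proto.Deduplicate.newBuilder()
  .setInput(input)
  .setAllColumnsAsKeys(true)
  .build()

// Explicit intent: deduplicate on a named subset of columns.
val bySubset = proto.Deduplicate.newBuilder()
  .setInput(input)
  .addColumnNames("key")
  .build()

// Neither field set: without the bool flag this shape would have to be
// interpreted as either "distinct on all columns" or "invalid plan".
// With the flag, it is unambiguously invalid and the planner rejects it.
val ambiguous = proto.Deduplicate.newBuilder()
  .setInput(input)
  .build()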

throw InvalidPlanInput(
"Deduplicate requires to either deduplicate on all columns or a subset of columns")
}
val queryExecution = new QueryExecution(session, transformRelation(rel.getInput))
val resolver = session.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
if (rel.getAllColumnsAsKeys) {
Deduplicate(allColumns, queryExecution.analyzed)
} else {
val toGroupColumnNames = rel.getColumnNamesList.asScala.toSeq
val groupCols = toGroupColumnNames.flatMap { (colName: String) =>
// It is possible that there is more than one column with the same name,
// so we call filter instead of find.
val cols = allColumns.filter(col => resolver(col.name, colName))
if (cols.isEmpty) {
throw InvalidPlanInput(s"Invalid deduplicate column ${colName}")
}
cols
}
Deduplicate(groupCols, queryExecution.analyzed)
}
}
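
The filter-instead-of-find handling above appears to mirror Dataset.dropDuplicates: when several output attributes share a name, all of them become grouping keys. A small illustrative sketch, assuming spark is an active SparkSession:

import spark.implicits._

// A frame with two distinct attributes that both carry the name "key".
val df = spark.range(3).select($"id".as("key"), ($"id" % 2).as("key"))

// Every attribute matching "key" is resolved into the grouping keys,
// which is the same behavior transformDeduplicate implements above.
df.dropDuplicates(Seq("key"))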

private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.planner

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
* [[SparkConnectPlanTestWithSparkSession]] contains a SparkSession for the connect planner.
*
* It is not recommended to use the Catalyst DSL along with this trait because `SharedSparkSession`
Contributor: This makes me think that we probably shouldn't use the Catalyst DSL at all. The tests need the Spark Connect planner, which needs a SparkSession. Some tests happen to not invoke the SparkSession, but it's a bit hacky to rely on this assumption. We should just compare plans produced by the proto DSL and the DataFrame APIs.

Contributor (@amaliujia, Oct 20, 2022): We decided to go with the Catalyst DSL in #37994. However, that also brought the pain of having to keep scopes small to avoid implicit conflicts.

Given that we need a session-based test, migrating all such tests to the same place that uses a session and the DataFrame API just makes sense. I also believe that with this approach we no longer have implicit conflicts in the same scope.

How about I send a follow-up PR for the testing refactoring after this PR? Or we can do the refactoring first (I prefer a separate PR) and then I can rebase this one. Either way is fine with me.

Contributor: We can do it in a followup.

* also defines implicits over the Catalyst LogicalPlan, which cause ambiguity with the
* implicits defined in the Catalyst DSL.
*/
trait SparkConnectPlanTestWithSparkSession extends SharedSparkSession with SparkConnectPlanTest {
override def getSession(): SparkSession = spark
}

class SparkConnectDeduplicateSuite extends SparkConnectPlanTestWithSparkSession {
lazy val connectTestRelation = createLocalRelationProto(
Seq(
AttributeReference("id", IntegerType)(),
AttributeReference("key", StringType)(),
AttributeReference("value", StringType)()))

lazy val sparkTestRelation = {
spark.createDataFrame(
new java.util.ArrayList[Row](),
StructType(
Seq(
StructField("id", IntegerType),
StructField("key", StringType),
StructField("value", StringType))))
}

test("Test basic deduplicate") {
val connectPlan = {
import org.apache.spark.sql.connect.dsl.plans._
Dataset.ofRows(spark, transform(connectTestRelation.distinct()))
}

val sparkPlan = sparkTestRelation.distinct()
comparePlans(connectPlan.queryExecution.analyzed, sparkPlan.queryExecution.analyzed, false)

val connectPlan2 = {
import org.apache.spark.sql.connect.dsl.plans._
Dataset.ofRows(spark, transform(connectTestRelation.deduplicate(Seq("key", "value"))))
}
Comment on lines +53 to +64

Contributor: Suggested change:

import org.apache.spark.sql.connect.dsl.plans._
val connectPlan = Dataset.ofRows(spark, transform(connectTestRelation.distinct()))
val sparkPlan = sparkTestRelation.distinct()
comparePlans(connectPlan.queryExecution.analyzed, sparkPlan.queryExecution.analyzed, false)
val connectPlan2 = Dataset.ofRows(spark, transform(connectTestRelation.deduplicate(Seq("key", "value"))))

Contributor: I think there was an issue here with the way that the two implicits of the Spark and Spark Connect DSLs are handled.

Contributor (@amaliujia, Oct 17, 2022): Yeah, here is some context that people may not know:

Scala seems to not allow two implicits defined in the same scope even when there is no ambiguity; in that case, Scala chooses to ignore one of the implementations. The workaround was to use a sub-scope to limit one implicit (the one for Connect) while its parent scope imports the other implicit.

See this comment for the context:

// TODO: Scala only allows one implicit per scope so we keep proto implicit imports in
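
To illustrate the workaround described here, a self-contained sketch (hypothetical DSL objects, not the real Spark ones): each conflicting implicit is confined to its own block, so only one extension of the method is visible at any point.

object ImplicitScopeDemo {
  final case class ProtoRelation(desc: String)
  final case class LogicalPlan(desc: String)

  object protoDsl {
    // Extension turning a name into a (fake) proto relation.
    implicit class ProtoOps(name: String) { def read: ProtoRelation = ProtoRelation(s"proto:$name") }
  }
  object catalystDsl {
    // Extension with the same method name, producing a (fake) logical plan.
    implicit class PlanOps(name: String) { def read: LogicalPlan = LogicalPlan(s"catalyst:$name") }
  }

  def main(args: Array[String]): Unit = {
    // Importing both DSLs into one scope would make "t".read ambiguous.
    // Confining each import to a sub-scope keeps resolution unambiguous:
    val protoRel = { import protoDsl._; "t".read }
    val plan = { import catalystDsl._; "t".read }
    println((protoRel, plan))
  }
}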

val sparkPlan2 = sparkTestRelation.dropDuplicates(Seq("key", "value"))
comparePlans(connectPlan2.queryExecution.analyzed, sparkPlan2.queryExecution.analyzed, false)
}
}
@@ -31,8 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
* test cases.
*/
trait SparkConnectPlanTest {

def getSession(): SparkSession = None.orNull
Contributor: Is it possible to reuse SharedSparkSessionBase?

Contributor (@amaliujia, Oct 17, 2022): This is related to the comment above: SharedSparkSession and its base define implicits that cause ambiguity with the Catalyst implicits. So for the current SparkConnectProtoSuite, we cannot let it inherit SharedSparkSession. Meanwhile, for testing this Deduplicate implementation, we need a session. This is why this PR does some refactoring on the test suites to create that separation.


def transform(rel: proto.Relation): LogicalPlan = {
new SparkConnectPlanner(rel, None.orNull).transform()
new SparkConnectPlanner(rel, getSession()).transform()
}

def readRel: proto.Relation =
@@ -72,8 +75,6 @@ trait SparkConnectPlanTest {
*/
class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

protected var spark: SparkSession = null

test("Simple Limit") {
assertThrows[IndexOutOfBoundsException] {
new SparkConnectPlanner(
@@ -266,4 +267,26 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
.build()))
assert(e.getMessage.contains("DataSource requires a format"))
}

test("Test invalid deduplicate") {
val deduplicate = proto.Deduplicate
.newBuilder()
.setInput(readRel)
.setAllColumnsAsKeys(true)
.addColumnNames("test")

val e = intercept[InvalidPlanInput] {
transform(proto.Relation.newBuilder.setDeduplicate(deduplicate).build())
}
assert(
e.getMessage.contains("Cannot deduplicate on both all columns and a subset of columns"))

val deduplicate2 = proto.Deduplicate
.newBuilder()
.setInput(readRel)
val e2 = intercept[InvalidPlanInput] {
transform(proto.Relation.newBuilder.setDeduplicate(deduplicate2).build())
}
assert(e2.getMessage.contains("either deduplicate on all columns or a subset of columns"))
}
}