14 changes: 14 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -50,6 +50,7 @@ message Relation {
RenameColumnsBySameLengthNames rename_columns_by_same_length_names = 18;
RenameColumnsByNameToNameMap rename_columns_by_name_to_name_map = 19;
ShowString show_string = 20;
Drop drop = 21;

// NA functions
NAFill fill_na = 90;
@@ -252,6 +253,19 @@ message Sort {
}
}


// Drop specified columns.
message Drop {
// (Required) The input relation.
Relation input = 1;

// (Required) columns to drop.
//
// Should contain at least 1 item.
repeated Expression cols = 2;
Contributor:
Does drop actually support arbitrary expressions? Shouldn't this be a repeated unresolved attribute?

Contributor (author):
Dataset.drop already takes arbitrary expressions into account:

val expressions = (for (col <- allColumns) yield col match {
  case Column(u: UnresolvedAttribute) =>
    queryExecution.analyzed.resolveQuoted(
      u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
  case Column(expr: Expression) => expr
})

Contributor:
Wondering if the name should be more explicit like "dropped"?

Contributor (author):
This follows the naming in def drop(col: Column, cols: Column*): DataFrame.

}
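For orientation (an editor's sketch, not part of this diff): with the generated Python bindings shown later in this PR, a client would populate the Drop message roughly as below, one unresolved attribute per dropped column. The import path and field names follow plan.py and the plan-only test in this diff.

from pyspark.sql.connect import proto

# Sketch: wrap a previously built child plan in a Drop relation.
child = proto.Relation()  # placeholder for the input relation

rel = proto.Relation()
rel.drop.input.CopyFrom(child)
for name in ("id", "name"):
    expr = proto.Expression()
    expr.unresolved_attribute.unparsed_identifier = name
    rel.drop.cols.extend([expr])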


// Relation of type [[Deduplicate]] which have duplicate rows removed, could consider either only
// the subset of columns or all the columns.
message Deduplicate {
@@ -510,6 +510,28 @@ package object dsl {
.build()
}

def drop(columns: String*): Relation = {
assert(columns.nonEmpty)

val cols = columns.map(col =>
Expression.newBuilder
.setUnresolvedAttribute(
Expression.UnresolvedAttribute.newBuilder
.setUnparsedIdentifier(col)
.build())
.build())

Relation
.newBuilder()
.setDrop(
Drop
.newBuilder()
.setInput(logicalPlan)
.addAllCols(cols.asJava)
.build())
.build()
}

def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
val agg = Aggregate.newBuilder()
agg.setInput(logicalPlan)
@@ -25,7 +25,7 @@ import com.google.common.collect.{Lists, Maps}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.WriteOperation
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.{Column, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
@@ -69,6 +69,7 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
@@ -523,6 +524,19 @@ class SparkConnectPlanner(session: SparkSession) {
sameOrderExpressions = Seq.empty)
}

private def transformDrop(rel: proto.Drop): LogicalPlan = {
assert(rel.getColsCount > 0, "cols must contain at least 1 item!")

val cols = rel.getColsList.asScala.toArray.map { expr =>
Column(transformExpression(expr))
Contributor:
This should verify supported types.

Contributor (author):
Do you mean verify for the Arrow-based collect?
Since we will remove the JSON code path, it always fails if there are unsupported types.

}

Dataset
.ofRows(session, transformRelation(rel.getInput))
.drop(cols.head, cols.tail: _*)
.logicalPlan
}

private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
assert(rel.hasInput)

@@ -148,6 +148,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}

test("SPARK-41169: Test drop") {
// single column
val connectPlan = connectTestRelation.drop("id")
val sparkPlan = sparkTestRelation.drop("id")
comparePlans(connectPlan, sparkPlan)

// all columns
val connectPlan2 = connectTestRelation.drop("id", "name")
val sparkPlan2 = sparkTestRelation.drop("id", "name")
comparePlans(connectPlan2, sparkPlan2)

// non-existing column
Contributor:
If you treat the dropped columns as expressions, we need to add a negative test for unsupported expressions.

Contributor (author):
will add

Contributor (author):
I just checked the implementation of Dataset.drop: it supports all kinds of expressions; an expression is simply ignored if it doesn't semanticEquals any column in the current DataFrame (a quick illustration follows after this test).

val connectPlan3 = connectTestRelation.drop("id2", "name")
val sparkPlan3 = sparkTestRelation.drop("id2", "name")
comparePlans(connectPlan3, sparkPlan3)
}
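A quick PySpark illustration of the behaviour described in the comment above (an editor's sketch, not part of the PR; it assumes a regular, non-Connect SparkSession): dropping an expression that doesn't semanticEquals any column, or a non-existing column name, is a no-op.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df = spark.range(3).withColumn("name", col("id").cast("string"))

# An arbitrary expression that matches no column is silently ignored.
assert df.drop(col("id") + 1).columns == ["id", "name"]
# A non-existing column name is likewise a no-op.
assert df.drop("id2").columns == ["id", "name"]
# An existing column is removed.
assert df.drop("id").columns == ["name"]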

test("SPARK-40809: column alias") {
// Simple Test.
val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
19 changes: 15 additions & 4 deletions python/pyspark/sql/connect/dataframe.py
@@ -257,10 +257,21 @@ def distinct(self) -> "DataFrame":
)

def drop(self, *cols: "ColumnOrString") -> "DataFrame":
Contributor:
This is an interesting case where one could argue for implementing the behavior on the client side instead of the server.

Contributor (author):
IIUC, there will be two RPCs if we implement it on the client side:
1. all_cols = self.columns to fetch the schema;
2. build the plan.

With a dedicated proto message, we only need one RPC (a sketch of the client-side alternative follows below).
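For comparison (an editor's sketch with hypothetical names, not code from this PR; it mirrors the client-side implementation being removed in this file): a client-side drop has to materialize the column list first, which on Spark Connect costs an extra schema round trip.

def drop_client_side(df, *cols):
    # df.columns triggers a schema RPC on Spark Connect -- the round trip that
    # the dedicated Drop relation avoids.
    dropped = {c if isinstance(c, str) else c.name() for c in cols}
    remaining = [c for c in df.columns if c not in dropped]
    return df.select(*remaining)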

Contributor @amaliujia (Nov 23, 2022):
I think we have built consensus that we prefer re-using the proto rather than asking clients to do duplicate work.

all_cols = self.columns
dropped = set([c.name() if isinstance(c, Column) else self[c].name() for c in cols])
dropped_cols = filter(lambda x: x in dropped, all_cols)
return DataFrame.withPlan(plan.Project(self._plan, *dropped_cols), session=self._session)
_cols = list(cols)
if any(not isinstance(c, (str, Column)) for c in _cols):
raise TypeError(
f"'cols' must contains strings or Columns, but got {type(cols).__name__}"
)
if len(_cols) == 0:
raise ValueError("'cols' must be non-empty")

return DataFrame.withPlan(
plan.Drop(
child=self._plan,
columns=_cols,
),
session=self._session,
)

def filter(self, condition: Expression) -> "DataFrame":
return DataFrame.withPlan(
43 changes: 43 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -464,6 +464,49 @@ def _repr_html_(self) -> str:
"""


class Drop(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
columns: List[Union[Column, str]],
) -> None:
super().__init__(child)
assert len(columns) > 0 and all(isinstance(c, (Column, str)) for c in columns)
self.columns = columns

def _convert_to_expr(
self, col: Union[Column, str], session: "RemoteSparkSession"
) -> proto.Expression:
expr = proto.Expression()
if isinstance(col, Column):
expr.CopyFrom(col.to_plan(session))
else:
expr.CopyFrom(self.unresolved_attr(col))
return expr

def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.drop.input.CopyFrom(self._child.plan(session))
plan.drop.cols.extend([self._convert_to_expr(c, session) for c in self.columns])
return plan

def print(self, indent: int = 0) -> str:
c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
return f"{' ' * indent}<Drop columns={self.columns}>\n{c_buf}"

def _repr_html_(self) -> str:
return f"""
<ul>
<li>
<b>Drop</b><br />
columns: {self.columns} <br />
{self._child_repr_()}
</li>
</ul>
"""


class Sample(LogicalPlan):
def __init__(
self,
150 changes: 82 additions & 68 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -79,6 +79,7 @@ class Relation(google.protobuf.message.Message):
RENAME_COLUMNS_BY_SAME_LENGTH_NAMES_FIELD_NUMBER: builtins.int
RENAME_COLUMNS_BY_NAME_TO_NAME_MAP_FIELD_NUMBER: builtins.int
SHOW_STRING_FIELD_NUMBER: builtins.int
DROP_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
SUMMARY_FIELD_NUMBER: builtins.int
CROSSTAB_FIELD_NUMBER: builtins.int
@@ -124,6 +125,8 @@
@property
def show_string(self) -> global___ShowString: ...
@property
def drop(self) -> global___Drop: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -156,6 +159,7 @@
rename_columns_by_same_length_names: global___RenameColumnsBySameLengthNames | None = ...,
rename_columns_by_name_to_name_map: global___RenameColumnsByNameToNameMap | None = ...,
show_string: global___ShowString | None = ...,
drop: global___Drop | None = ...,
fill_na: global___NAFill | None = ...,
summary: global___StatSummary | None = ...,
crosstab: global___StatCrosstab | None = ...,
@@ -172,6 +176,8 @@
b"crosstab",
"deduplicate",
b"deduplicate",
"drop",
b"drop",
"fill_na",
b"fill_na",
"filter",
@@ -227,6 +233,8 @@
b"crosstab",
"deduplicate",
b"deduplicate",
"drop",
b"drop",
"fill_na",
b"fill_na",
"filter",
@@ -293,6 +301,7 @@
"rename_columns_by_same_length_names",
"rename_columns_by_name_to_name_map",
"show_string",
"drop",
"fill_na",
"summary",
"crosstab",
@@ -961,6 +970,42 @@

global___Sort = Sort

class Drop(google.protobuf.message.Message):
"""Drop specified columns."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

INPUT_FIELD_NUMBER: builtins.int
COLS_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) The input relation."""
@property
def cols(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
pyspark.sql.connect.proto.expressions_pb2.Expression
]:
"""(Required) columns to drop.

Should contain at least 1 item.
"""
def __init__(
self,
*,
input: global___Relation | None = ...,
cols: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
| None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input"]
) -> builtins.bool: ...
def ClearField(
self, field_name: typing_extensions.Literal["cols", b"cols", "input", b"input"]
) -> None: ...

global___Drop = Drop

class Deduplicate(google.protobuf.message.Message):
"""Relation of type [[Deduplicate]] which have duplicate rows removed, could consider either only
the subset of columns or all the columns.
27 changes: 27 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -188,6 +188,33 @@ def test_take(self) -> None:
df2 = self.connect.read.table(self.tbl_name_empty)
self.assertEqual(0, len(df2.take(5)))

def test_drop(self):
# SPARK-41169: test drop
query = """
SELECT * FROM VALUES
(false, 1, NULL), (false, NULL, 2), (NULL, 3, 3)
AS tab(a, b, c)
"""

cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.drop("a").toPandas(),
sdf.drop("a").toPandas(),
)
self.assert_eq(
cdf.drop("a", "b").toPandas(),
sdf.drop("a", "b").toPandas(),
)
self.assert_eq(
cdf.drop("a", "x").toPandas(),
sdf.drop("a", "x").toPandas(),
)
self.assert_eq(
cdf.drop(cdf.a, cdf.x).toPandas(),
sdf.drop("a", "x").toPandas(),
)

def test_subquery_alias(self) -> None:
# SPARK-40938: test subquery alias.
plan_text = (
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -181,6 +181,22 @@ def test_sort(self):
)
self.assertEqual(plan.root.sort.is_global, False)

def test_drop(self):
# SPARK-41169: test drop
df = self.connect.readTable(table_name=self.tbl_name)

plan = df.filter(df.col_name > 3).drop("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(
[f.unresolved_attribute.unparsed_identifier for f in plan.root.drop.cols],
["col_a", "col_b"],
)

plan = df.filter(df.col_name > 3).drop(df.col_x, "col_b")._plan.to_proto(self.connect)
self.assertEqual(
[f.unresolved_attribute.unparsed_identifier for f in plan.root.drop.cols],
["col_x", "col_b"],
)

def test_deduplicate(self):
df = self.connect.readTable(table_name=self.tbl_name)
