[SPARK-41169][CONNECT][PYTHON] Implement DataFrame.drop (#38686)
In the Spark Connect protobuf relation definitions:
@@ -50,6 +50,7 @@ message Relation {
     RenameColumnsBySameLengthNames rename_columns_by_same_length_names = 18;
     RenameColumnsByNameToNameMap rename_columns_by_name_to_name_map = 19;
     ShowString show_string = 20;
+    Drop drop = 21;

     // NA functions
     NAFill fill_na = 90;
@@ -252,6 +253,19 @@ message Sort {
   }
 }

+// Drop specified columns.
+message Drop {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) columns to drop.
+  //
+  // Should contain at least 1 item.
+  repeated Expression cols = 2;
+}
+
 // Relation of type [[Deduplicate]] which have duplicate rows removed, could consider either only
 // the subset of columns or all the columns.
 message Deduplicate {
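For orientation, here is a minimal sketch of populating the new message through the generated Python bindings. This is not part of the PR: the module path is an assumption about where the generated stubs live, and only fields defined in the diff above are touched.

```python
# A sketch, not from this PR: build a Relation carrying the new Drop message.
# Assumes generated protobuf stubs under pyspark.sql.connect.proto.
from pyspark.sql.connect.proto import relations_pb2 as relations

rel = relations.Relation()
rel.drop.input.CopyFrom(relations.Relation())  # placeholder child relation
rel.drop.cols.add()                            # one (still empty) Expression to drop
assert rel.WhichOneof("rel_type") == "drop"    # the new field (tag 21) is now set
```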
In the server-side planner (SparkConnectPlanner):
@@ -25,7 +25,7 @@ import com.google.common.collect.{Lists, Maps}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.WriteOperation
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
@@ -69,6 +69,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
       case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
       case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
+      case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
       case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
       case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
       case proto.Relation.RelTypeCase.LOCAL_RELATION =>
@@ -523,6 +524,19 @@ class SparkConnectPlanner(session: SparkSession) {
       sameOrderExpressions = Seq.empty)
   }

+  private def transformDrop(rel: proto.Drop): LogicalPlan = {
+    assert(rel.getColsCount > 0, s"cols must contains at least 1 item!")
+
+    val cols = rel.getColsList.asScala.toArray.map { expr =>
+      Column(transformExpression(expr))
+    }
+
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .drop(cols.head, cols.tail: _*)
+      .logicalPlan
+  }
+
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
     assert(rel.hasInput)
In the planner test suite (SparkConnectProtoSuite):
@@ -148,6 +148,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan2, sparkPlan2)
   }

+  test("SPARK-41169: Test drop") {
+    // single column
+    val connectPlan = connectTestRelation.drop("id")
+    val sparkPlan = sparkTestRelation.drop("id")
+    comparePlans(connectPlan, sparkPlan)
+
+    // all columns
+    val connectPlan2 = connectTestRelation.drop("id", "name")
+    val sparkPlan2 = sparkTestRelation.drop("id", "name")
+    comparePlans(connectPlan2, sparkPlan2)
+
+    // non-existing column
+    val connectPlan3 = connectTestRelation.drop("id2", "name")
+    val sparkPlan3 = sparkTestRelation.drop("id2", "name")
+    comparePlans(connectPlan3, sparkPlan3)
+  }
+
   test("SPARK-40809: column alias") {
     // Simple Test.
     val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
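The third case leans on Spark's drop semantics: names that do not resolve are silently ignored rather than raising an error. A hedged PySpark illustration, with made-up data and names:

```python
# Illustration, not from this PR: drop ignores names that do not resolve.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

print(df.drop("id2", "name").columns)  # ['id'] -- "id2" is silently ignored
```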
In the Spark Connect Python client's DataFrame:
@@ -257,10 +257,21 @@ def distinct(self) -> "DataFrame":
         )

     def drop(self, *cols: "ColumnOrString") -> "DataFrame":
-        all_cols = self.columns
-        dropped = set([c.name() if isinstance(c, Column) else self[c].name() for c in cols])
-        dropped_cols = filter(lambda x: x in dropped, all_cols)
-        return DataFrame.withPlan(plan.Project(self._plan, *dropped_cols), session=self._session)
+        _cols = list(cols)
+        if any(not isinstance(c, (str, Column)) for c in _cols):
+            raise TypeError(
+                f"'cols' must contains strings or Columns, but got {type(cols).__name__}"
+            )
+        if len(_cols) == 0:
+            raise ValueError("'cols' must be non-empty")
+
+        return DataFrame.withPlan(
+            plan.Drop(
+                child=self._plan,
+                columns=_cols,
+            ),
+            session=self._session,
+        )

     def filter(self, condition: Expression) -> "DataFrame":
         return DataFrame.withPlan(
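From the user's side, a hedged sketch of the reworked method; it assumes a Spark Connect DataFrame named df with columns "id" and "name", and everything here is illustrative rather than taken from the PR:

```python
# Sketch, not from this PR: exercising the reworked drop.
df.drop("name")            # drop by name
df.drop(df["id"], "name")  # Column and str arguments can be mixed

# Invalid calls now fail eagerly on the client, before any plan is sent:
try:
    df.drop()              # ValueError: 'cols' must be non-empty
except ValueError as e:
    print(e)
try:
    df.drop(123)           # TypeError: 'cols' must contains strings or Columns, ...
except TypeError as e:
    print(e)
```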
(One additional large file diff is not rendered by default.)
Review comment: Does drop actually support arbitrary expressions? Shouldn't this be a repeated unresolved attribute?
Reply: Dataset.drop takes arbitrary expressions into account; see spark/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala, lines 2952 to 2957 at commit 3b4faaf.
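To illustrate the reply (a hedged example, not from the PR): Column arguments are resolved against the plan, which matters when a bare name string would be ambiguous, such as after a self-referencing join:

```python
# Illustration: dropping by Column disambiguates where a name string cannot.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
left = spark.createDataFrame([(1, "a")], ["id", "name"])
right = spark.createDataFrame([(1, "x")], ["id", "tag"])

joined = left.join(right, left["id"] == right["id"])
# Only the right-hand "id" is removed; a plain "id" string could not say which one.
print(joined.drop(right["id"]).columns)  # ['id', 'name', 'tag']
```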