From 5e0bb6cb0f4999780f5959dca19ac02c7af3e59c Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 3 Nov 2022 15:59:09 -0700 Subject: [PATCH 1/3] [SPARK-41010][CONNECT][PYTHON] Complete Support for Except and Intersect in Python client. --- python/pyspark/sql/connect/dataframe.py | 82 ++++++++++++++++++- python/pyspark/sql/connect/plan.py | 38 +++++++-- .../tests/connect/test_connect_plan_only.py | 20 +++++ 3 files changed, 130 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index b9ba4b99ba0a..de9efb53ac03 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -291,7 +291,9 @@ def union(self, other: "DataFrame") -> "DataFrame": def unionAll(self, other: "DataFrame") -> "DataFrame": if other._plan is None: raise ValueError("Argument to Union does not contain a valid plan.") - return DataFrame.withPlan(plan.UnionAll(self._plan, other._plan), session=self._session) + return DataFrame.withPlan( + plan.SetOperation(self._plan, other._plan, "union", is_all=True), session=self._session + ) def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame": """Returns a new :class:`DataFrame` containing union of rows in this and another @@ -317,7 +319,83 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> if other._plan is None: raise ValueError("Argument to UnionByName does not contain a valid plan.") return DataFrame.withPlan( - plan.UnionAll(self._plan, other._plan, allowMissingColumns), session=self._session + plan.SetOperation( + self._plan, other._plan, "union", is_all=True, by_name=allowMissingColumns + ), + session=self._session, + ) + + def exceptAll(self, other: "DataFrame") -> "DataFrame": + """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but + not in another :class:`DataFrame` while preserving duplicates. + + This is equivalent to `EXCEPT ALL` in SQL. 
+ As standard in SQL, this function resolves columns by position (not by name). + + .. versionadded:: 3.4.0 + + Parameters + ---------- + other : :class:`DataFrame` + The other :class:`DataFrame` to compare to. + + Returns + ------- + :class:`DataFrame` + """ + return DataFrame.withPlan( + plan.SetOperation(self._plan, other._plan, "except", is_all=True), session=self._session + ) + + def intersect(self, other: "DataFrame") -> "DataFrame": + """Return a new :class:`DataFrame` containing rows only in + both this :class:`DataFrame` and another :class:`DataFrame`. + Note that any duplicates are removed. To preserve duplicates + use :func:`intersectAll`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + other : :class:`DataFrame` + Another :class:`DataFrame` that needs to be combined. + + Returns + ------- + :class:`DataFrame` + Combined DataFrame. + + Notes + ----- + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame.withPlan( + plan.SetOperation(self._plan, other._plan, "intersect", is_all=False), + session=self._session, + ) + + def intersectAll(self, other: "DataFrame") -> "DataFrame": + """Return a new :class:`DataFrame` containing rows in both this :class:`DataFrame` + and another :class:`DataFrame` while preserving duplicates. + + This is equivalent to `INTERSECT ALL` in SQL. As standard in SQL, this function + resolves columns by position (not by name). + + .. versionadded:: 3.4.0 + + Parameters + ---------- + other : :class:`DataFrame` + Another :class:`DataFrame` that needs to be combined. + + Returns + ------- + :class:`DataFrame` + Combined DataFrame. 
+ """ + return DataFrame.withPlan( + plan.SetOperation(self._plan, other._plan, "intersect", is_all=True), + session=self._session, ) def where(self, condition: Expression) -> "DataFrame": diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index cc59a493d5ad..f015ba9b5798 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -605,21 +605,43 @@ def _repr_html_(self) -> str: """ -class UnionAll(LogicalPlan): +class SetOperation(LogicalPlan): def __init__( - self, child: Optional["LogicalPlan"], other: "LogicalPlan", by_name: bool = False + self, + child: Optional["LogicalPlan"], + other: Optional["LogicalPlan"], + set_op: str, + is_all: bool = True, + by_name: bool = False, ) -> None: super().__init__(child) self.other = other self.by_name = by_name + self.is_all = is_all + self.set_op = set_op def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: assert self._child is not None rel = proto.Relation() - rel.set_op.left_input.CopyFrom(self._child.plan(session)) - rel.set_op.right_input.CopyFrom(self.other.plan(session)) - rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION - rel.set_op.is_all = True + if self._child is not None: + rel.set_op.left_input.CopyFrom(self._child.plan(session)) + if self.other is not None: + rel.set_op.right_input.CopyFrom(self.other.plan(session)) + if self.set_op == "union": + rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION + elif self.set_op == "intersect": + rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT + elif self.set_op == "except": + rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT + else: + raise NotImplementedError( + """ + Unsupported set operation type: %s. 
+ """ + % self.set_op + ) + + rel.set_op.is_all = self.is_all rel.set_op.by_name = self.by_name return rel @@ -631,7 +653,7 @@ def print(self, indent: int = 0) -> str: o = " " * (indent + LogicalPlan.INDENT) n = indent + LogicalPlan.INDENT * 2 return ( - f"{i}UnionAll\n{o}child1=\n{self._child.print(n)}" + f"{i}SetOperation\n{o}child1=\n{self._child.print(n)}" f"\n{o}child2=\n{self.other.print(n)}" ) @@ -642,7 +664,7 @@ def _repr_html_(self) -> str: return f"""