From 6bed316481147a8c2c4693a11344f7845a8fc498 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Mon, 17 Oct 2022 13:00:19 -0700
Subject: [PATCH 1/4] test lint

---
 python/pyspark/sql/connect/plan.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 10f19aa00f23..b52a8be06d73 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -49,6 +49,7 @@ class LogicalPlan(object):
     def __init__(self, child: Optional["LogicalPlan"]) -> None:
         self._child = child
+

     def unresolved_attr(self, *colNames: str) -> proto.Expression:
         """Creates an unresolved attribute from a column name."""
         exp = proto.Expression()

From 19c2ae9ea8075f4fa3855550480a0f29811790f9 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Mon, 17 Oct 2022 13:09:41 -0700
Subject: [PATCH 2/4] [SPARK-40796][BUILD][FOLLOW-UP] Fix `unused "type: ignore" comment`.

---
 python/pyspark/sql/connect/plan.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index b52a8be06d73..eabc792d8895 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -49,7 +49,6 @@ class LogicalPlan(object):
     def __init__(self, child: Optional["LogicalPlan"]) -> None:
         self._child = child
-

     def unresolved_attr(self, *colNames: str) -> proto.Expression:
         """Creates an unresolved attribute from a column name."""
         exp = proto.Expression()
@@ -323,13 +322,13 @@ def _convert_measure(
     ) -> proto.Aggregate.AggregateFunction:
         exp, fun = m
         measure = proto.Aggregate.AggregateFunction()
-        measure.function.name = fun  # type: ignore[attr-defined]
+        measure.function.name = fun
         if type(exp) is str:
-            measure.function.arguments.append(  # type: ignore[attr-defined]
+            measure.function.arguments.append(
                 self.unresolved_attr(exp)
             )
         else:
-            measure.function.arguments.append(  # type: ignore[attr-defined]
+            measure.function.arguments.append(
                 cast(Expression, exp).to_plan(session)
             )
         return measure
@@ -340,13 +339,13 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:

         agg = proto.Relation()
         agg.aggregate.input.CopyFrom(self._child.plan(session))
-        agg.aggregate.measures.extend(  # type: ignore[attr-defined]
+        agg.aggregate.measures.extend(
             list(map(lambda x: self._convert_measure(x, session), self.measures))
         )

-        gs = proto.Aggregate.GroupingSet()  # type: ignore[attr-defined]
+        gs = proto.Aggregate.GroupingSet()
         gs.aggregate_expressions.extend(groupings)
-        agg.aggregate.grouping_sets.append(gs)  # type: ignore[attr-defined]
+        agg.aggregate.grouping_sets.append(gs)
         return agg

     def print(self, indent: int = 0) -> str:

From d1e87bed8bfc998932f9c17009bdd14b0876b3b4 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Mon, 17 Oct 2022 14:23:48 -0700
Subject: [PATCH 3/4] update

---
 python/pyspark/sql/connect/plan.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index eabc792d8895..486778b9d374 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -324,13 +324,9 @@ def _convert_measure(
         measure = proto.Aggregate.AggregateFunction()
         measure.function.name = fun
         if type(exp) is str:
-            measure.function.arguments.append(
-                self.unresolved_attr(exp)
-            )
+            measure.function.arguments.append(self.unresolved_attr(exp))
         else:
-            measure.function.arguments.append(
-                cast(Expression, exp).to_plan(session)
-            )
+            measure.function.arguments.append(cast(Expression, exp).to_plan(session))
         return measure

     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:

From 0794ee1087e6be93c4fac7fa2cc3f49adccc4a2a Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Mon, 17 Oct 2022 16:28:01 -0700
Subject: [PATCH 4/4] update

---
 python/pyspark/sql/connect/plan.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 486778b9d374..da7c5cf56981 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -322,11 +322,11 @@ def _convert_measure(
     ) -> proto.Aggregate.AggregateFunction:
         exp, fun = m
         measure = proto.Aggregate.AggregateFunction()
-        measure.function.name = fun
+        measure.name = fun
         if type(exp) is str:
-            measure.function.arguments.append(self.unresolved_attr(exp))
+            measure.arguments.append(self.unresolved_attr(exp))
         else:
-            measure.function.arguments.append(cast(Expression, exp).to_plan(session))
+            measure.arguments.append(cast(Expression, exp).to_plan(session))
         return measure

     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
@@ -335,13 +335,11 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:

         agg = proto.Relation()
         agg.aggregate.input.CopyFrom(self._child.plan(session))
-        agg.aggregate.measures.extend(
+        agg.aggregate.result_expressions.extend(
             list(map(lambda x: self._convert_measure(x, session), self.measures))
         )

-        gs = proto.Aggregate.GroupingSet()
-        gs.aggregate_expressions.extend(groupings)
-        agg.aggregate.grouping_sets.append(gs)
+        agg.aggregate.grouping_expressions.extend(groupings)
         return agg

     def print(self, indent: int = 0) -> str: