diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 31215b4da792b..8e34419fc1164 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -247,12 +247,12 @@ def toPandas(self) -> Optional["pandas.DataFrame"]:
             raise Exception("Cannot collect on empty plan.")
         if self._session is None:
             raise Exception("Cannot collect on empty session.")
-        query = self._plan.collect(self._session)
+        query = self._plan.to_proto(self._session)
         return self._session._to_pandas(query)
 
     def explain(self) -> str:
         if self._plan is not None:
-            query = self._plan.collect(self._session)
+            query = self._plan.to_proto(self._session)
             if self._session is None:
                 raise Exception("Cannot analyze without RemoteSparkSession.")
             return self._session.analyze(query).explain_string
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index da7c5cf56981b..9351998c19561 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -80,9 +80,19 @@ def _verify(self, session: "RemoteSparkSession") -> bool:
 
         return test_plan == plan
 
-    def collect(
+    def to_proto(
         self, session: Optional["RemoteSparkSession"] = None, debug: bool = False
     ) -> proto.Plan:
+        """
+        Generates a Spark Connect proto plan based on this LogicalPlan.
+
+        Parameters
+        ----------
+        session : :class:`RemoteSparkSession`, optional
+            a session that connects to a remote Spark cluster.
+        debug : bool
+            if enabled, the proto plan will be printed.
+        """
         plan = proto.Plan()
         plan.root.CopyFrom(self.plan(session))
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 2aa686bbc3823..faba0c7cf4b73 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -45,7 +45,7 @@ def test_simple_column_expressions(self):
     def test_column_literals(self):
         df = c.DataFrame.withPlan(p.Read("table"))
         lit_df = df.select(fun.lit(10))
-        self.assertIsNotNone(lit_df._plan.collect(None))
+        self.assertIsNotNone(lit_df._plan.to_proto(None))
         self.assertIsNotNone(fun.lit(10).to_plan(None))
 
         plan = fun.lit(10).to_plan(None)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 8fb33beb3677c..72345f9244253 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -27,13 +27,13 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
     generation but do not call Spark."""
 
     def test_simple_project(self):
-        plan = self.connect.readTable(table_name=self.tbl_name)._plan.collect(self.connect)
+        plan = self.connect.readTable(table_name=self.tbl_name)._plan.to_proto(self.connect)
         self.assertIsNotNone(plan.root, "Root relation must be set")
         self.assertIsNotNone(plan.root.read)
 
     def test_filter(self):
         df = self.connect.readTable(table_name=self.tbl_name)
-        plan = df.filter(df.col_name > 3)._plan.collect(self.connect)
+        plan = df.filter(df.col_name > 3)._plan.to_proto(self.connect)
         self.assertIsNotNone(plan.root.filter)
         self.assertTrue(
             isinstance(
@@ -45,7 +45,7 @@ def test_filter(self):
 
     def test_relation_alias(self):
         df = self.connect.readTable(table_name=self.tbl_name)
-        plan = df.alias("table_alias")._plan.collect(self.connect)
+        plan = df.alias("table_alias")._plan.to_proto(self.connect)
         self.assertEqual(plan.root.common.alias, "table_alias")
 
     def test_simple_udf(self):
@@ -59,7 +59,7 @@ def test_simple_udf(self):
     def test_all_the_plans(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         df = df.select(df.col1).filter(df.col2 == 2).sort(df.col3.asc())
-        plan = df._plan.collect(self.connect)
+        plan = df._plan.to_proto(self.connect)
         self.assertIsNotNone(plan.root, "Root relation must be set")
         self.assertIsNotNone(plan.root.read)
diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py
index 37a64abcc5edf..3df9ec9a3bca5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py
+++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py
@@ -24,7 +24,7 @@ class SparkConnectToProtoSuite(PlanOnlyTestFixture):
 
     def test_select_with_literal(self):
         df = DataFrame.withPlan(Read("table"))
-        self.assertIsNotNone(df.select(col("name"))._plan.collect())
+        self.assertIsNotNone(df.select(col("name"))._plan.to_proto())
         self.assertRaises(InputValidationError, df.select, "name")
 
     def test_join_with_join_type(self):
@@ -39,7 +39,7 @@ def test_join_with_join_type(self):
             ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
             ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
         ]:
-            joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.collect()
+            joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto()
             self.assertEqual(joined_df.root.join.join_type, join_type)
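For context, a minimal sketch of the renamed entry point, mirroring the plan-only tests above. Module paths follow the test files in this diff; the table name is a placeholder, and exact imports on this branch may differ:

    from pyspark.sql.connect.dataframe import DataFrame
    from pyspark.sql.connect.plan import Read

    # Build a client-side DataFrame backed only by a logical plan; no
    # connection to a live cluster is needed to construct the plan.
    df = DataFrame.withPlan(Read("example_table"))

    # to_proto() (formerly collect()) serializes the LogicalPlan into a
    # Spark Connect proto.Plan. The session argument is optional for a
    # plan-only conversion, and debug=True would print the generated plan.
    plan = df._plan.to_proto()
    print(plan.root.read)  # the Read relation sits at the root of the plan

The rename leaves behavior unchanged; it presumably avoids confusion with DataFrame.collect(), which executes a query, whereas this method only produces the serialized plan.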