diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e3116ea12502..26124f6bc345 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -140,13 +140,8 @@ def columns(self) -> List[str]:
         """Returns the list of columns of the current data frame."""
         if self._plan is None:
             return []
-        if "columns" not in self._cache and self._plan is not None:
-            pdd = self.limit(0).toPandas()
-            if pdd is None:
-                raise Exception("Empty result")
-            # Translate to standard pytho array
-            self._cache["columns"] = pdd.columns.values
-        return self._cache["columns"]
+
+        return self.schema().names
 
     def count(self) -> int:
         """Returns the number of rows in the data frame"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a0f046907f73..aaf0a73699f7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -106,6 +106,11 @@ def test_simple_read(self):
         # Check that the limit is applied
         self.assertEqual(len(data.index), 10)
 
+    def test_columns(self):
+        # SPARK-41036: test `columns` API for python client.
+        columns = self.connect.read.table(self.tbl_name).columns
+        self.assertEqual(["id", "name"], columns)
+
     def test_collect(self):
         df = self.connect.read.table(self.tbl_name)
         data = df.limit(10).collect()
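
For reference, a minimal runnable sketch (plain PySpark types, not the Connect client) of why the rewritten property can simply return self.schema().names: StructType keeps its field names in the `names` attribute, so the column list comes straight from the schema instead of a limit(0).toPandas() round trip. The table layout here mirrors the "id"/"name" columns asserted in test_columns above.

    from pyspark.sql.types import LongType, StringType, StructField, StructType

    # Schema matching the test table used above: columns "id" and "name".
    schema = StructType(
        [
            StructField("id", LongType()),
            StructField("name", StringType()),
        ]
    )

    # StructType.names is the list of field names, which is what the
    # rewritten `columns` property returns via self.schema().names.
    print(schema.names)  # ['id', 'name']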