diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 15aa028b11b1..275a6d2668d0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -115,6 +115,9 @@ def __init__(
         self._cache: Dict[str, Any] = {}
         self._session: "RemoteSparkSession" = session
 
+    def __repr__(self) -> str:
+        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
     @classmethod
     def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame":
         """Main initialization method used to construct a new data frame with a child plan."""
@@ -137,13 +140,26 @@ def approxQuantile(self, col: Column, probabilities: Any, relativeError: Any) ->
     def colRegex(self, regex: str) -> "DataFrame":
         ...
 
+    @property
+    def dtypes(self) -> List[Tuple[str, str]]:
+        """Returns all column names and their data types as a list.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        list
+            List of columns as tuple pairs.
+        """
+        return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
+
     @property
     def columns(self) -> List[str]:
         """Returns the list of columns of the current data frame."""
         if self._plan is None:
             return []
-        return self.schema().names
+        return self.schema.names
 
     def sparkSession(self) -> "RemoteSparkSession":
         """Returns Spark session that created this :class:`DataFrame`.
@@ -736,6 +752,7 @@ def toPandas(self) -> Optional["pandas.DataFrame"]:
         query = self._plan.to_proto(self._session)
         return self._session._to_pandas(query)
 
+    @property
     def schema(self) -> StructType:
         """Returns the schema of this :class:`DataFrame` as a
         :class:`pyspark.sql.types.StructType`.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d3de94a379f8..234f33796706 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -137,7 +137,7 @@ def test_simple_explain_string(self):
         self.assertGreater(len(result), 0)
 
     def test_schema(self):
-        schema = self.connect.read.table(self.tbl_name).schema()
+        schema = self.connect.read.table(self.tbl_name).schema
         self.assertEqual(
             StructType(
                 [StructField("id", LongType(), True), StructField("name", StringType(), True)]
@@ -280,6 +280,14 @@ def test_show(self):
         expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n"
         self.assertEqual(show_str, expected)
 
+    def test_repr(self):
+        # SPARK-41213: Test the __repr__ method
+        query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
+        self.assertEqual(
+            self.connect.sql(query).__repr__(),
+            self.spark.sql(query).__repr__(),
+        )
+
     def test_explain_string(self):
         # SPARK-41122: test explain API.
         plan_str = self.connect.sql("SELECT 1").explain(extended=True)
@@ -327,8 +335,7 @@ def test_alias(self) -> None:
         col0 = (
             self.connect.range(1, 10)
             .select(col("id").alias("name", metadata={"max": 99}))
-            .schema()
-            .names[0]
+            .schema.names[0]
         )
         self.assertEqual("name", col0)