19 changes: 18 additions & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -115,6 +115,9 @@ def __init__(
self._cache: Dict[str, Any] = {}
self._session: "RemoteSparkSession" = session

def __repr__(self) -> str:
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
Contributor Author:
This follows the default behavior of the existing (non-Connect) DataFrame.__repr__:

def __repr__(self) -> str:
if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled():
vertical = False
return self._jdf.showString(
self.sparkSession._jconf.replEagerEvalMaxNumRows(),
self.sparkSession._jconf.replEagerEvalTruncate(),
vertical,
)
else:
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

Contributor:
Question: is this public API?

Contributor Author:
It will be invoked implicitly whenever the DataFrame is displayed in a REPL:

In [1]: df = spark.createDataFrame([(10, 80, "Alice"), (5, None, "Bob"), (None, 10, "Tom"), (None, None, None)], schema=["age", "height", "name"])

In [2]: df
Out[2]: DataFrame[age: bigint, height: bigint, name: string]

In [3]: df.__repr__()
Out[3]: 'DataFrame[age: bigint, height: bigint, name: string]'

Contributor:
Thanks for the example!


@classmethod
def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame":
"""Main initialization method used to construct a new data frame with a child plan."""
@@ -137,13 +140,26 @@ def approxQuantile(self, col: Column, probabilities: Any, relativeError: Any) ->
def colRegex(self, regex: str) -> "DataFrame":
...

@property
def dtypes(self) -> List[Tuple[str, str]]:
"""Returns all column names and their data types as a list.

.. versionadded:: 3.4.0

Returns
-------
list
List of columns as tuple pairs.
"""
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
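A minimal usage sketch (illustrative only, reusing the df from the REPL session above):

In [4]: df.dtypes
Out[4]: [('age', 'bigint'), ('height', 'bigint'), ('name', 'string')]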

@property
def columns(self) -> List[str]:
"""Returns the list of columns of the current data frame."""
if self._plan is None:
return []

-        return self.schema().names
+        return self.schema.names
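An illustrative sketch of the behavior this change preserves (same df as above):

In [5]: df.columns
Out[5]: ['age', 'height', 'name']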

def sparkSession(self) -> "RemoteSparkSession":
"""Returns Spark session that created this :class:`DataFrame`.
@@ -736,6 +752,7 @@ def toPandas(self) -> Optional["pandas.DataFrame"]:
query = self._plan.to_proto(self._session)
return self._session._to_pandas(query)

@property
def schema(self) -> StructType:
"""Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`.

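Note for callers: schema becomes a property here, so invocations change from df.schema() to df.schema (the test updates below reflect this). A hedged sketch with the same df:

In [6]: isinstance(df.schema, StructType)
Out[6]: True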
13 changes: 10 additions & 3 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -137,7 +137,7 @@ def test_simple_explain_string(self):
self.assertGreater(len(result), 0)

def test_schema(self):
-        schema = self.connect.read.table(self.tbl_name).schema()
+        schema = self.connect.read.table(self.tbl_name).schema
self.assertEqual(
StructType(
[StructField("id", LongType(), True), StructField("name", StringType(), True)]
@@ -280,6 +280,14 @@ def test_show(self):
expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n"
self.assertEqual(show_str, expected)

def test_repr(self):
# SPARK-41213: Test the __repr__ method
query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
self.assertEqual(
self.connect.sql(query).__repr__(),
self.spark.sql(query).__repr__(),
)
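With eager evaluation disabled (the default), both reprs should render the schema-string form for this query, i.e. something like:

DataFrame[a: bigint, b: string]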

def test_explain_string(self):
# SPARK-41122: test explain API.
plan_str = self.connect.sql("SELECT 1").explain(extended=True)
@@ -327,8 +335,7 @@ def test_alias(self) -> None:
col0 = (
self.connect.range(1, 10)
.select(col("id").alias("name", metadata={"max": 99}))
-            .schema()
-            .names[0]
+            .schema.names[0]
)
self.assertEqual("name", col0)
