
Commit 6a0b77a

xuanyuanking authored and gatorsmile committed
[SPARK-24215][PYSPARK][FOLLOW UP] Implement eager evaluation for DataFrame APIs in PySpark
## What changes were proposed in this pull request?

Address comments in #21370 and add more tests.

## How was this patch tested?

Enhanced the tests in python/pyspark/sql/tests.py and DataFrameSuite.

Author: Yuanjian Li <[email protected]>

Closes #21553 from xuanyuanking/SPARK-24215-follow.
1 parent a1a64e3 commit 6a0b77a
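For readers unfamiliar with the feature this follow-up refines, here is a minimal sketch of the user-facing behavior, assuming an active PySpark session; the DataFrame contents are illustrative:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Opt in to eager evaluation; this is the config key defined in this commit.
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")

df = spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

# In a plain Python REPL, evaluating `df` now prints a show()-style ASCII
# table via __repr__; in notebooks like Jupyter, _repr_html_ renders an
# HTML table instead.
print(repr(df))
```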

File tree

6 files changed (+131 / -38 lines)


docs/configuration.md

Lines changed: 0 additions & 27 deletions
```diff
@@ -456,33 +456,6 @@ Apart from these, the following properties are also available, and may be useful
   from JVM to Python worker for every task.
   </td>
 </tr>
-<tr>
-  <td><code>spark.sql.repl.eagerEval.enabled</code></td>
-  <td>false</td>
-  <td>
-    Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation,
-    Dataset will be ran automatically. The HTML table which generated by <code>_repl_html_</code>
-    called by notebooks like Jupyter will feedback the queries user have defined. For plain Python
-    REPL, the output will be shown like <code>dataframe.show()</code>
-    (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
-  </td>
-</tr>
-<tr>
-  <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
-  <td>20</td>
-  <td>
-    Default number of rows in eager evaluation output HTML table generated by <code>_repr_html_</code> or plain text,
-    this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
-  </td>
-</tr>
-<tr>
-  <td><code>spark.sql.repl.eagerEval.truncate</code></td>
-  <td>20</td>
-  <td>
-    Default number of truncate in eager evaluation output HTML table generated by <code>_repr_html_</code> or
-    plain text, this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> set to true.
-  </td>
-</tr>
 <tr>
   <td><code>spark.files</code></td>
   <td></td>
```

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -393,9 +393,8 @@ def _repr_html_(self):
         self._support_repr_html = True
         if self._eager_eval:
             max_num_rows = max(self._max_num_rows, 0)
-            vertical = False
             sock_info = self._jdf.getRowsToPython(
-                max_num_rows, self._truncate, vertical)
+                max_num_rows, self._truncate)
             rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
             head = rows[0]
             row_data = rows[1:]
```
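For orientation, the rows fetched over the socket arrive as a header row followed by string-formatted data rows. Below is a simplified, hypothetical sketch of how `_repr_html_` turns them into an HTML table; it mirrors the structure of the method above, not its exact code:

```python
import html

def rows_to_html_table(rows):
    """Hypothetical helper mirroring _repr_html_'s table construction:
    rows[0] is the header row, rows[1:] are string-formatted data cells."""
    head, row_data = rows[0], rows[1:]
    out = ["<table border='1'>"]
    out.append("<tr>" + "".join("<th>%s</th>" % html.escape(c) for c in head) + "</tr>")
    for row in row_data:
        out.append("<tr>" + "".join("<td>%s</td>" % html.escape(c) for c in row) + "</tr>")
    out.append("</table>")
    return "\n".join(out)

# Example: a header plus two data rows, matching the test data in this PR.
print(rows_to_html_table([["key", "value"], ["1", "1"], ["22222", "22222"]]))
```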

python/pyspark/sql/tests.py

Lines changed: 44 additions & 2 deletions
```diff
@@ -3351,11 +3351,41 @@ def test_checking_csv_header(self):
         finally:
             shutil.rmtree(path)
 
-    def test_repr_html(self):
+    def test_repr_behaviors(self):
         import re
         pattern = re.compile(r'^ *\|', re.MULTILINE)
         df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
-        self.assertEquals(None, df._repr_html_())
+
+        # test when eager evaluation is enabled and _repr_html_ will not be called
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+            expected1 = """+-----+-----+
+                ||  key|value|
+                |+-----+-----+
+                ||    1|    1|
+                ||22222|22222|
+                |+-----+-----+
+                |"""
+            self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                expected2 = """+---+-----+
+                    ||key|value|
+                    |+---+-----+
+                    ||  1|    1|
+                    ||222|  222|
+                    |+---+-----+
+                    |"""
+                self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
+            with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                expected3 = """+---+-----+
+                    ||key|value|
+                    |+---+-----+
+                    ||  1|    1|
+                    |+---+-----+
+                    |only showing top 1 row
+                    |"""
+                self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
+
+        # test when eager evaluation is enabled and _repr_html_ will be called
         with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """<table border='1'>
                 |<tr><th>key</th><th>value</th></tr>
@@ -3381,6 +3411,18 @@ def test_repr_html(self):
                 |"""
             self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
 
+        # test when eager evaluation is disabled and _repr_html_ will be called
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
+            expected = "DataFrame[key: bigint, value: string]"
+            self.assertEquals(None, df._repr_html_())
+            self.assertEquals(expected, df.__repr__())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                self.assertEquals(None, df._repr_html_())
+                self.assertEquals(expected, df.__repr__())
+            with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                self.assertEquals(None, df._repr_html_())
+                self.assertEquals(expected, df.__repr__())
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
```
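These tests lean on a `sql_conf` helper that temporarily overrides SQL configs and restores them on exit. The real helper lives in the PySpark test utilities as a method on the test class; a minimal, illustrative sketch of such a context manager:

```python
from contextlib import contextmanager

@contextmanager
def sql_conf(spark, pairs):
    """Illustrative re-implementation of the tests' sql_conf helper:
    set the given SQL configs, then restore the previous values on exit."""
    keys = list(pairs.keys())
    # Remember the current values (None if the key was unset).
    old_values = [spark.conf.get(key, None) for key in keys]
    try:
        for key, value in pairs.items():
            spark.conf.set(key, str(value))
        yield
    finally:
        for key, old in zip(keys, old_values):
            if old is None:
                spark.conf.unset(key)
            else:
                spark.conf.set(key, old)

# Usage, mirroring the tests:
# with sql_conf(spark, {"spark.sql.repl.eagerEval.enabled": True}):
#     print(repr(df))
```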

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 23 additions & 0 deletions
```diff
@@ -1330,6 +1330,29 @@ object SQLConf {
       "The size function returns null for null input if the flag is disabled.")
     .booleanConf
     .createWithDefault(true)
+
+  val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled")
+    .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " +
+      "displayed if and only if the REPL supports the eager evaluation. Currently, the " +
+      "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " +
+      "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " +
+      "the returned outputs are formatted like dataframe.show().")
+    .booleanConf
+    .createWithDefault(false)
+
+  val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows")
+    .doc("The max number of rows that are returned by eager evaluation. This only takes " +
+      "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " +
+      "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " +
+      "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).")
+    .intConf
+    .createWithDefault(20)
+
+  val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate")
+    .doc("The max number of characters for each cell that is returned by eager evaluation. " +
+      "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.")
+    .intConf
+    .createWithDefault(20)
 }
 
 /**
```
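Since these are runtime SQL configs, they can be toggled per session without a restart. A brief example of exercising all three from PySpark; the data and values are illustrative:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
# Per the doc string above, out-of-range values are normalized to [0, Int.MaxValue - 1].
spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", "5")
# Maximum characters shown per cell before truncation.
spark.conf.set("spark.sql.repl.eagerEval.truncate", "15")

# At most 5 rows are printed, and each cell is cut down to 15 characters.
df = spark.range(10).selectExpr("id", "repeat('x', 30) AS long_col")
print(repr(df))
```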

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 4 additions & 7 deletions
```diff
@@ -236,12 +236,10 @@ class Dataset[T] private[sql](
    * @param numRows Number of rows to return
    * @param truncate If set to more than 0, truncates strings to `truncate` characters and
    *                 all cells will be aligned right.
-   * @param vertical If set to true, the rows to return do not need truncate.
    */
   private[sql] def getRows(
       numRows: Int,
-      truncate: Int,
-      vertical: Boolean): Seq[Seq[String]] = {
+      truncate: Int): Seq[Seq[String]] = {
     val newDf = toDF()
     val castCols = newDf.logicalPlan.output.map { col =>
       // Since binary types in top-level schema fields have a specific format to print,
@@ -289,7 +287,7 @@ class Dataset[T] private[sql](
       vertical: Boolean = false): String = {
     val numRows = _numRows.max(0).min(Int.MaxValue - 1)
     // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
-    val tmpRows = getRows(numRows, truncate, vertical)
+    val tmpRows = getRows(numRows, truncate)
 
     val hasMoreData = tmpRows.length - 1 > numRows
     val rows = tmpRows.take(numRows + 1)
@@ -3226,11 +3224,10 @@ class Dataset[T] private[sql](
 
   private[sql] def getRowsToPython(
       _numRows: Int,
-      truncate: Int,
-      vertical: Boolean): Array[Any] = {
+      truncate: Int): Array[Any] = {
     EvaluatePython.registerPicklers()
     val numRows = _numRows.max(0).min(Int.MaxValue - 1)
-    val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+    val rows = getRows(numRows, truncate).map(_.toArray).toArray
     val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
     val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
       rows.iterator.map(toJava))
```
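The `showString` context above relies on a fetch-one-extra pattern: `getRows` returns up to `numRows + 1` data rows, and the presence of that extra row is what triggers the "only showing top N rows" footer. A small sketch of that logic, written in Python for brevity with illustrative names:

```python
def take_with_more_flag(all_rows, num_rows):
    """Illustrative version of showString's N+1 trick: fetch one row beyond
    the limit purely to learn whether more data exists."""
    header, data = all_rows[0], all_rows[1:]
    fetched = data[:num_rows + 1]            # getRows limits to numRows + 1 rows
    has_more_data = len(fetched) > num_rows  # the extra row proves there is more
    shown = fetched[:num_rows]
    return header, shown, has_more_data

header, shown, more = take_with_more_flag(
    [["key"], ["1"], ["2"], ["3"]], num_rows=2)
if more:
    print("only showing top %d row%s" % (len(shown), "s" if len(shown) != 1 else ""))
```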

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 59 additions & 0 deletions
```diff
@@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     testData.select($"*").show(1000)
   }
 
+  test("getRows: truncate = [0, 20]") {
+    val longString = Array.fill(21)("1").mkString
+    val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+    val expectedAnswerForFalse = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("111111111111111111111"))
+    assert(df.getRows(10, 0) === expectedAnswerForFalse)
+    val expectedAnswerForTrue = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("11111111111111111..."))
+    assert(df.getRows(10, 20) === expectedAnswerForTrue)
+  }
+
+  test("getRows: truncate = [3, 17]") {
+    val longString = Array.fill(21)("1").mkString
+    val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+    val expectedAnswerForFalse = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("111"))
+    assert(df.getRows(10, 3) === expectedAnswerForFalse)
+    val expectedAnswerForTrue = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("11111111111111..."))
+    assert(df.getRows(10, 17) === expectedAnswerForTrue)
+  }
+
+  test("getRows: numRows = 0") {
+    val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1"))
+    assert(testData.select($"*").getRows(0, 20) === expectedAnswer)
+  }
+
+  test("getRows: array") {
+    val df = Seq(
+      (Array(1, 2, 3), Array(1, 2, 3)),
+      (Array(2, 3, 4), Array(2, 3, 4))
+    ).toDF()
+    val expectedAnswer = Seq(
+      Seq("_1", "_2"),
+      Seq("[1, 2, 3]", "[1, 2, 3]"),
+      Seq("[2, 3, 4]", "[2, 3, 4]"))
+    assert(df.getRows(10, 20) === expectedAnswer)
+  }
+
+  test("getRows: binary") {
+    val df = Seq(
+      ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)),
+      ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
+    ).toDF()
+    val expectedAnswer = Seq(
+      Seq("_1", "_2"),
+      Seq("[31 32]", "[41 42 43 2E]"),
+      Seq("[33 34]", "[31 32 33 34 36]"))
+    assert(df.getRows(10, 20) === expectedAnswer)
+  }
+
   test("showString: truncate = [0, 20]") {
     val longString = Array.fill(21)("1").mkString
     val df = sparkContext.parallelize(Seq("1", longString)).toDF()
```
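The truncate tests above pin down the cell-truncation rule: `truncate = 0` leaves cells untouched; for `truncate >= 4`, an overlong cell keeps `truncate - 3` characters plus `"..."`; for smaller positive values the cell is simply cut to `truncate` characters, hence `"111"` for `truncate = 3`. A quick Python sketch reproducing those expectations, with a hypothetical helper name:

```python
def truncate_cell(cell, truncate):
    """Cell-truncation rule implied by the getRows tests above."""
    if truncate <= 0 or len(cell) <= truncate:
        return cell          # no truncation requested, or the cell fits
    if truncate < 4:
        return cell[:truncate]              # too short for an ellipsis
    return cell[:truncate - 3] + "..."      # keep truncate-3 chars plus "..."

long_string = "1" * 21
assert truncate_cell(long_string, 0) == "1" * 21
assert truncate_cell(long_string, 20) == "1" * 17 + "..."
assert truncate_cell(long_string, 17) == "1" * 14 + "..."
assert truncate_cell(long_string, 3) == "111"
```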
