Skip to content

Commit 7f098bc

Browse files
committed
add crosstab pyTest
1 parent fd53b00 commit 7f098bc

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -899,17 +899,12 @@ def crosstab(self, col1, col2):
899899
900900
:param col1: The name of the first column
901901
:param col2: The name of the second column
902-
903-
>>> df3.crosstab("age", "height").show()
904-
age_height 80 85
905-
2 1 1
906-
5 1 1
907902
"""
908903
if not isinstance(col1, str):
909904
raise ValueError("col1 should be a string.")
910905
if not isinstance(col2, str):
911906
raise ValueError("col2 should be a string.")
912-
return self._jdf.stat().crosstab(col1, col2)
907+
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
913908

914909
@ignore_unicode_prefix
915910
def withColumn(self, colName, col):

python/pyspark/sql/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,14 @@ def test_cov(self):
392392
cov = df.stat.cov("a", "b")
393393
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
394394

395+
def test_crosstab(self):
    """Check that ``df.stat.crosstab`` yields one row per distinct ``a`` value
    with a count of exactly 1 in every ``b`` column.

    The input pairs ``(i % 3, i % 2)`` for ``i`` in ``1..6`` cover each of the
    six (a, b) combinations exactly once, so every cell of the contingency
    table must be 1.
    """
    df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
    ct = df.stat.crosstab("a", "b")
    # The order of collected rows is not established by anything in this
    # test, so sort on the label column before comparing positionally.
    rows = sorted(ct.collect(), key=lambda r: r[0])
    self.assertEqual(len(rows), 3)
    for i, row in enumerate(rows):
        self.assertEqual(row[0], str(i))
        # assertEqual, not assertTrue: assertTrue(x, 1) treats 1 as the
        # failure message and would accept any truthy count.
        self.assertEqual(row[1], 1)
        self.assertEqual(row[2], 1)
402+
395403
def test_math_functions(self):
396404
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
397405
from pyspark.sql import mathfunctions as functions

0 commit comments

Comments
 (0)