Skip to content

Commit 8055411

Browse files
brkyvz authored and rxin committed
[SPARK-7243][SQL] Contingency Tables for DataFrames
Computes a pair-wise frequency table of the given columns. Also known as cross-tabulation.

cc mengxr rxin

Author: Burak Yavuz <[email protected]>

Closes #5842 from brkyvz/df-cont and squashes the following commits:

a07c01e [Burak Yavuz] addressed comments v4.1
ae9e01d [Burak Yavuz] fix test
9106585 [Burak Yavuz] addressed comments v4.0
bced829 [Burak Yavuz] fix merge conflicts
a63ad00 [Burak Yavuz] addressed comments v3.0
a0cad97 [Burak Yavuz] addressed comments v3.0
6805df8 [Burak Yavuz] addressed comments and fixed test
939b7c4 [Burak Yavuz] lint python
7f098bc [Burak Yavuz] add crosstab pyTest
fd53b00 [Burak Yavuz] added python support for crosstab
27a5a81 [Burak Yavuz] implemented crosstab
1 parent fc8b581 commit 8055411

File tree

6 files changed

+160
-31
lines changed

6 files changed

+160
-31
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,26 @@ def cov(self, col1, col2):
931931
raise ValueError("col2 should be a string.")
932932
return self._jdf.stat().cov(col1, col2)
933933

934+
def crosstab(self, col1, col2):
    """
    Computes a pair-wise frequency table of the given columns. Also known as a contingency
    table. The number of distinct values for each column should be less than 1e4. The first
    column of each row will be the distinct values of `col1` and the column names will be the
    distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that
    have no occurrences will have `null` as their counts.
    :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.

    :param col1: The name of the first column. Distinct items will make the first item of
        each row.
    :param col2: The name of the second column. Distinct items will make the column names
        of the DataFrame.
    :return: a :class:`DataFrame` wrapping the table computed by the JVM-side `crosstab`.
    :raises ValueError: if `col1` or `col2` is not a string.
    """
    # Python 2 has two text types (str and unicode) sharing the `basestring` base class, so a
    # plain `str` check would reject unicode column names there. Python 3 has no
    # `basestring`; `str` is its only text type.
    try:
        string_types = basestring
    except NameError:
        string_types = str
    if not isinstance(col1, string_types):
        raise ValueError("col1 should be a string.")
    if not isinstance(col2, string_types):
        raise ValueError("col2 should be a string.")
    return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
953+
934954
def freqItems(self, cols, support=None):
935955
"""
936956
Finding frequent items for columns, possibly with false positives. Using the
@@ -1423,6 +1443,11 @@ def cov(self, col1, col2):
14231443

14241444
cov.__doc__ = DataFrame.cov.__doc__
14251445

1446+
def crosstab(self, col1, col2):
1447+
return self.df.crosstab(col1, col2)
1448+
1449+
crosstab.__doc__ = DataFrame.crosstab.__doc__
1450+
14261451
def freqItems(self, cols, support=None):
14271452
return self.df.freqItems(cols, support)
14281453

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,15 @@ def test_cov(self):
405405
cov = df.stat.cov("a", "b")
406406
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
407407

408+
def test_crosstab(self):
    """Every (a, b) pair built below occurs exactly once, so every count must be 1."""
    df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
    ct = df.stat.crosstab("a", "b").collect()
    # Sort on the first column so that the i-th row is the one for a == i.
    ct = sorted(ct, key=lambda x: x[0])
    for i, row in enumerate(ct):
        self.assertEqual(row[0], str(i))
        # assertTrue(value, 1) would treat 1 as the failure message and pass for any truthy
        # value; assertEqual actually verifies the count.
        self.assertEqual(row[1], 1)
        self.assertEqual(row[2], 1)
416+
408417
def test_math_functions(self):
409418
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
410419
from pyspark.sql import mathfunctions as functions

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ import org.apache.spark.sql.execution.stat._
2828
final class DataFrameStatFunctions private[sql](df: DataFrame) {
2929

3030
/**
31+
* Calculate the sample covariance of two numerical columns of a DataFrame.
32+
* @param col1 the name of the first column
33+
* @param col2 the name of the second column
34+
* @return the covariance of the two columns.
35+
*/
36+
def cov(col1: String, col2: String): Double = {
37+
StatFunctions.calculateCov(df, Seq(col1, col2))
38+
}
39+
40+
/**
3141
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
3242
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
3343
* MLlib's Statistics.
@@ -53,6 +63,23 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
5363
corr(col1, col2, "pearson")
5464
}
5565

66+
/**
67+
* Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
68+
* The number of distinct values for each column should be less than 1e4. The first
69+
* column of each row will be the distinct values of `col1` and the column names will be the
70+
* distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts will be
71+
* returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
72+
*
73+
* @param col1 The name of the first column. Distinct items will make the first item of
74+
* each row.
75+
* @param col2 The name of the second column. Distinct items will make the column names
76+
* of the DataFrame.
77+
* @return A Local DataFrame containing the table
78+
*/
79+
def crosstab(col1: String, col2: String): DataFrame = {
80+
StatFunctions.crossTabulate(df, col1, col2)
81+
}
82+
5683
/**
5784
* Finding frequent items for columns, possibly with false positives. Using the
5885
* frequent element count algorithm described in
@@ -94,14 +121,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
94121
def freqItems(cols: Seq[String]): DataFrame = {
95122
FrequentItems.singlePassFreqItems(df, cols, 0.01)
96123
}
97-
98-
/**
99-
* Calculate the sample covariance of two numerical columns of a DataFrame.
100-
* @param col1 the name of the first column
101-
* @param col2 the name of the second column
102-
* @return the covariance of the two columns.
103-
*/
104-
def cov(col1: String, col2: String): Double = {
105-
StatFunctions.calculateCov(df, Seq(col1, col2))
106-
}
107124
}

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717

1818
package org.apache.spark.sql.execution.stat
1919

20-
import org.apache.spark.sql.catalyst.expressions.Cast
20+
import org.apache.spark.Logging
2121
import org.apache.spark.sql.{Column, DataFrame}
22-
import org.apache.spark.sql.types.{DoubleType, NumericType}
22+
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
23+
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
24+
import org.apache.spark.sql.functions._
25+
import org.apache.spark.sql.types._
2326

24-
private[sql] object StatFunctions {
27+
private[sql] object StatFunctions extends Logging {
2528

2629
/** Calculate the Pearson Correlation Coefficient for the given columns */
2730
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
@@ -95,4 +98,32 @@ private[sql] object StatFunctions {
9598
val counts = collectStatisticalData(df, cols)
9699
counts.cov
97100
}
101+
102+
/**
 * Generate a table of frequencies for the elements of two columns.
 *
 * The grouped counts are fetched with `take` (at most 1e8 pairs) and pivoted locally: one
 * output row per distinct value of `col1` and one `Long` column per distinct value of `col2`.
 * Pairs that never occur are left unset in the output row.
 *
 * @param df the DataFrame to tabulate
 * @param col1 name of the column whose distinct values become the first item of each row
 * @param col2 name of the column whose distinct values become the column names
 * @return a DataFrame over a `LocalRelation`; its first column is named `${col1}_$col2`
 */
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
  val tableName = s"${col1}_$col2"
  val maxPairs = 1e8.toInt
  val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(maxPairs)
  if (counts.length == maxPairs) {
    logWarning("The maximum limit of 1e8 pairs have been collected, which may not be all of " +
      "the pairs. Please try reducing the amount of distinct items in your columns.")
  }
  // Guard against null column values: calling toString on them directly would throw.
  def cleanElement(element: Any): String = if (element == null) "null" else element.toString
  // get the distinct values of column 2, so that we can make them the column names
  val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
  val columnSize = distinctCol2.size
  require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
    s"exceed 1e4. Currently $columnSize")
  val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
    val countsRow = new GenericMutableRow(columnSize + 1)
    rows.foreach { row =>
      // slot 0 holds the col1 value, so the count for col2 index i goes to slot i + 1
      countsRow.setLong(distinctCol2(row.get(1)) + 1, row.getLong(2))
    }
    // the value of col1 is the first value, the rest are the counts
    countsRow.setString(0, cleanElement(col1Item))
    countsRow
  }.toSeq
  // A Map's iteration order is unspecified, so emit the headers in the order of the index
  // stored for each value; that index is what positions the counts in every row above.
  val headerNames = distinctCol2.toSeq.sortBy(_._2).map { case (col2Item, _) =>
    StructField(cleanElement(col2Item), LongType)
  }
  val schema = StructType(StructField(tableName, StringType) +: headerNames)

  new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
}
98129
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import java.io.Serializable;
3636
import java.util.Arrays;
37+
import java.util.Comparator;
3738
import java.util.List;
3839
import java.util.Map;
3940

@@ -178,6 +179,33 @@ public void testCreateDataFrameFromJavaBeans() {
178179
Assert.assertEquals(bean.getD().get(i), d.apply(i));
179180
}
180181
}
182+
183+
private static Comparator<Row> CrosstabRowComparator = new Comparator<Row>() {
184+
public int compare(Row row1, Row row2) {
185+
String item1 = row1.getString(0);
186+
String item2 = row2.getString(0);
187+
return item1.compareTo(item2);
188+
}
189+
};
190+
191+
@Test
192+
public void testCrosstab() {
193+
DataFrame df = context.table("testData2");
194+
DataFrame crosstab = df.stat().crosstab("a", "b");
195+
String[] columnNames = crosstab.schema().fieldNames();
196+
Assert.assertEquals(columnNames[0], "a_b");
197+
Assert.assertEquals(columnNames[1], "1");
198+
Assert.assertEquals(columnNames[2], "2");
199+
Row[] rows = crosstab.collect();
200+
Arrays.sort(rows, CrosstabRowComparator);
201+
Integer count = 1;
202+
for (Row row : rows) {
203+
Assert.assertEquals(row.get(0).toString(), count.toString());
204+
Assert.assertEquals(row.getLong(1), 1L);
205+
Assert.assertEquals(row.getLong(2), 1L);
206+
count++;
207+
}
208+
}
181209

182210
@Test
183211
public void testFrequentItems() {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,9 @@ import org.apache.spark.sql.test.TestSQLContext
2424
import org.apache.spark.sql.test.TestSQLContext.implicits._
2525

2626
class DataFrameStatSuite extends FunSuite {
27-
28-
import TestData._
27+
2928
val sqlCtx = TestSQLContext
3029
def toLetter(i: Int): String = (i + 97).toChar.toString
31-
32-
test("Frequent Items") {
33-
val rows = Seq.tabulate(1000) { i =>
34-
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
35-
}
36-
val df = rows.toDF("numbers", "letters", "negDoubles")
37-
38-
val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
39-
val items = results.collect().head
40-
items.getSeq[Int](0) should contain (1)
41-
items.getSeq[String](1) should contain (toLetter(1))
42-
43-
val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
44-
val items2 = singleColResults.collect().head
45-
items2.getSeq[Double](0) should contain (-1.0)
46-
}
4730

4831
test("pearson correlation") {
4932
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -76,7 +59,43 @@ class DataFrameStatSuite extends FunSuite {
7659
intercept[IllegalArgumentException] {
7760
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
7861
}
62+
val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
7963
val decimalRes = decimalData.stat.cov("a", "b")
8064
assert(math.abs(decimalRes) < 1e-12)
8165
}
66+
67+
test("crosstab") {
68+
val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
69+
val crosstab = df.stat.crosstab("a", "b")
70+
val columnNames = crosstab.schema.fieldNames
71+
assert(columnNames(0) === "a_b")
72+
assert(columnNames(1) === "0")
73+
assert(columnNames(2) === "1")
74+
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
75+
assert(rows(0).get(0).toString === "0")
76+
assert(rows(0).getLong(1) === 2L)
77+
assert(rows(0).get(2) === null)
78+
assert(rows(1).get(0).toString === "1")
79+
assert(rows(1).getLong(1) === 1L)
80+
assert(rows(1).get(2) === null)
81+
assert(rows(2).get(0).toString === "2")
82+
assert(rows(2).getLong(1) === 2L)
83+
assert(rows(2).getLong(2) === 1L)
84+
}
85+
86+
test("Frequent Items") {
87+
val rows = Seq.tabulate(1000) { i =>
88+
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
89+
}
90+
val df = rows.toDF("numbers", "letters", "negDoubles")
91+
92+
val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
93+
val items = results.collect().head
94+
items.getSeq[Int](0) should contain (1)
95+
items.getSeq[String](1) should contain (toLetter(1))
96+
97+
val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
98+
val items2 = singleColResults.collect().head
99+
items2.getSeq[Double](0) should contain (-1.0)
100+
}
82101
}

0 commit comments

Comments
 (0)