Commit 27a5a81

implemented crosstab

1 parent 7630213 commit 27a5a81
4 files changed, +88 -1 lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 15 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.stat.FrequentItems
+import org.apache.spark.sql.execution.stat.{ContingencyTable, FrequentItems}
 
 /**
  * :: Experimental ::
@@ -27,6 +27,20 @@ import org.apache.spark.sql.execution.stat.FrequentItems
 @Experimental
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
+  /**
+   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+   * The number of distinct values for each column should be less than Int.MaxValue. The first
+   * column of each row will be the distinct values of `col1` and the column names will be the
+   * distinct values of `col2`, sorted in lexicographical order. Counts will be returned as `Long`s.
+   *
+   * @param col1 The name of the first column.
+   * @param col2 The name of the second column.
+   * @return A local DataFrame containing the table.
+   */
+  def crosstab(col1: String, col2: String): DataFrame = {
+    ContingencyTable.crossTabulate(df, col1, col2)
+  }
+
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
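For context, a minimal usage sketch of the new API (not part of the commit). The data and the `ctx` SQLContext here are hypothetical. Note from the implementation below that the pivot chunks the sorted counts into runs of `distinctCol2`, so this sketch uses data in which every (col1, col2) combination occurs at least once.

// Hypothetical usage sketch for df.stat.crosstab (Spark SQL 1.x-era API).
import org.apache.spark.sql.SQLContext

def crosstabExample(ctx: SQLContext): Unit = {
  import ctx.implicits._
  // Every (key, value) combination appears at least once.
  val df = Seq(("a", 1), ("a", 2), ("b", 1), ("b", 1), ("b", 2)).toDF("key", "value")
  val ct = df.stat.crosstab("key", "value")
  // First column is named "key_value"; the remaining columns are the
  // distinct values of `value` in lexicographical order, with Long counts.
  ct.show()
  // key_value 1 2
  // a         1 1
  // b         2 1
}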
sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.functions._
+
+
+private[sql] object ContingencyTable {
+
+  /** Generate a table of frequencies for the elements of two columns. */
+  private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+    val tableName = s"${col1}_$col2"
+    val distinctVals = df.select(countDistinct(col1), countDistinct(col2)).collect().head
+    val distinctCol1 = distinctVals.getLong(0)
+    val distinctCol2 = distinctVals.getLong(1)
+
+    require(distinctCol1 < Int.MaxValue, s"The number of distinct values for $col1 can't " +
+      s"exceed Int.MaxValue. Currently $distinctCol1")
+    require(distinctCol2 < Int.MaxValue, s"The number of distinct values for $col2 can't " +
+      s"exceed Int.MaxValue. Currently $distinctCol2")
+    // Aggregate the counts for the two columns
+    val allCounts =
+      df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).orderBy(col1, col2).collect()
+    // Pivot the table
+    val pivotedTable = allCounts.grouped(distinctCol2.toInt).toArray
+    // Get the column names (distinct values of col2)
+    val headerNames = pivotedTable.head.map(r => StructField(r.get(1).toString, LongType))
+    val schema = StructType(StructField(tableName, StringType) +: headerNames)
+    val table = pivotedTable.map { rows =>
+      // the value of col1 is the first value, the rest are the counts
+      val rowValues = rows.head.get(0).toString +: rows.map(_.getLong(2))
+      Row(rowValues: _*)
+    }
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+  }
+
+}
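The interesting step here is the pivot: because `allCounts` is sorted by (col1, col2), and every combination is implicitly assumed to be present, each consecutive run of `distinctCol2` rows belongs to a single col1 value. A standalone sketch of that idea on plain tuples (names are illustrative, not part of the commit):

// Sketch of the pivot used above: chunk sorted (col1, col2, count) triples
// into runs of `distinctCol2`, one run per distinct col1 value.
object PivotSketch {
  def pivot(sortedCounts: Seq[(String, String, Long)], distinctCol2: Int): Seq[Seq[Any]] =
    sortedCounts.grouped(distinctCol2).map { run =>
      // All triples in a run share the same col1 value; keep it once,
      // then keep only the counts.
      run.head._1 +: run.map(_._3)
    }.toSeq

  def main(args: Array[String]): Unit = {
    val counts = Seq(("a", "1", 1L), ("a", "2", 1L), ("b", "1", 2L), ("b", "2", 1L))
    pivot(counts, distinctCol2 = 2).foreach(println)
    // Prints: List(a, 1, 1) and List(b, 2, 1)
  }
}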

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 18 additions & 0 deletions
@@ -178,6 +178,24 @@ public void testCreateDataFrameFromJavaBeans() {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
+
+  @Test
+  public void testCrosstab() {
+    DataFrame df = context.table("testData2");
+    DataFrame crosstab = df.stat().crosstab("a", "b");
+    String[] columnNames = crosstab.schema().fieldNames();
+    Assert.assertEquals(columnNames[0], "a_b");
+    Assert.assertEquals(columnNames[1], "1");
+    Assert.assertEquals(columnNames[2], "2");
+    Row[] rows = crosstab.collect();
+    Integer count = 1;
+    for (Row row : rows) {
+      Assert.assertEquals(row.get(0).toString(), count.toString());
+      Assert.assertEquals(row.getLong(1), 1L);
+      Assert.assertEquals(row.getLong(2), 1L);
+      count++;
+    }
+  }
 
   @Test
   public void testFrequentItems() {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 17 additions & 0 deletions
@@ -24,9 +24,26 @@ import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite {
+  import TestData._
 
   val sqlCtx = TestSQLContext
 
+  test("crosstab") {
+    val crosstab = testData2.stat.crosstab("a", "b")
+    val columnNames = crosstab.schema.fieldNames
+    assert(columnNames(0) === "a_b")
+    assert(columnNames(1) === "1")
+    assert(columnNames(2) === "2")
+    val rows: Array[Row] = crosstab.collect()
+    var count: Integer = 1
+    rows.foreach { row =>
+      assert(row.get(0).toString === count.toString)
+      assert(row.getLong(1) === 1L)
+      assert(row.getLong(2) === 1L)
+      count += 1
+    }
+  }
+
   test("Frequent Items") {
     def toLetter(i: Int): String = (i + 96).toChar.toString
     val rows = Array.tabulate(1000) { i =>
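The expected counts follow from the fixture: `testData2` is defined in TestData (not shown in this diff) and is assumed to contain each (a, b) combination for a in 1..3 and b in 1..2 exactly once, which is why every cell in the crosstab is 1. A sketch of that assumed shape:

// Assumed contents of the testData2 fixture (from TestData; not part of this diff).
case class TestData2(a: Int, b: Int)

val testData2Fixture = Seq(
  TestData2(1, 1), TestData2(1, 2),
  TestData2(2, 1), TestData2(2, 2),
  TestData2(3, 1), TestData2(3, 2))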
