Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ public abstract class BloomFilter {
public enum Version {
/**
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Total number of words of the underlying bit array (32 bit)
* - The words/longs (numWords * 64 bit)
* - Number of hash functions (32 bit)
* <ul>
* <li>Version number, always 1 (32 bit)</li>
* <li>Total number of words of the underlying bit array (32 bit)</li>
* <li>The words/longs (numWords * 64 bit)</li>
* <li>Number of hash functions (32 bit)</li>
* </ul>
*/
V1(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,22 @@ abstract public class CountMinSketch {
public enum Version {
/**
* {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Total count of added items (64 bit)
* - Depth (32 bit)
* - Width (32 bit)
* - Hash functions (depth * 64 bit)
* - Count table
* - Row 0 (width * 64 bit)
* - Row 1 (width * 64 bit)
* - ...
* - Row depth - 1 (width * 64 bit)
* <ul>
* <li>Version number, always 1 (32 bit)</li>
* <li>Total count of added items (64 bit)</li>
* <li>Depth (32 bit)</li>
* <li>Width (32 bit)</li>
* <li>Hash functions (depth * 64 bit)</li>
* <li>
* Count table
* <ul>
* <li>Row 0 (width * 64 bit)</li>
* <li>Row 1 (width * 64 bit)</li>
* <li>...</li>
* <li>Row {@code depth - 1} (width * 64 bit)</li>
* </ul>
* </li>
* </ul>
*/
V1(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;
class CountMinSketchImpl extends CountMinSketch implements Serializable {
private static final long PRIME_MODULUS = (1L << 31) - 1;

private int depth;
private int width;
Expand All @@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch {
private double eps;
private double confidence;

/**
 * Creates an empty, uninitialized sketch. Kept private: instances built this way are
 * expected to have their fields populated afterwards (e.g. by {@code readFrom0}).
 */
private CountMinSketchImpl() {}

CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
this.width = width;
Expand All @@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch {
initTablesWith(depth, width, seed);
}

/**
 * Builds a sketch directly from already-materialized state.
 *
 * <p>The arrays are stored by reference; no defensive copy is made.
 *
 * @param depth depth of the sketch
 * @param width width of the sketch
 * @param totalCount total count of items added so far
 * @param hashA hash function coefficients
 * @param table the count table
 */
CountMinSketchImpl(int depth, int width, long totalCount, long[] hashA, long[][] table) {
  // Raw state handed over by the caller.
  this.totalCount = totalCount;
  this.hashA = hashA;
  this.table = table;

  // Dimensions and the accuracy parameters derived from them.
  this.depth = depth;
  this.width = width;
  this.confidence = 1 - 1 / Math.pow(2, depth);
  this.eps = 2.0 / width;
}

@Override
public boolean equals(Object other) {
if (other == this) {
Expand Down Expand Up @@ -325,27 +321,43 @@ public void writeTo(OutputStream out) throws IOException {
}

public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
CountMinSketchImpl sketch = new CountMinSketchImpl();
sketch.readFrom0(in);
return sketch;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we move this code into CountMinSketch.readFrom?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I prefer not, because then we'll have to make the no-arg constructor package private instead of private.

}

private void readFrom0(InputStream in) throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this name is quite weird...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is actually a common naming style in java - to have the private version named xxx0

DataInputStream dis = new DataInputStream(in);

// Ignores version number
dis.readInt();
int version = dis.readInt();
if (version != Version.V1.getVersionNumber()) {
throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
}

long totalCount = dis.readLong();
int depth = dis.readInt();
int width = dis.readInt();
this.totalCount = dis.readLong();
this.depth = dis.readInt();
this.width = dis.readInt();
this.eps = 2.0 / width;
this.confidence = 1 - 1 / Math.pow(2, depth);

long hashA[] = new long[depth];
this.hashA = new long[depth];
for (int i = 0; i < depth; ++i) {
hashA[i] = dis.readLong();
this.hashA[i] = dis.readLong();
}

long table[][] = new long[depth][width];
this.table = new long[depth][width];
for (int i = 0; i < depth; ++i) {
for (int j = 0; j < width; ++j) {
table[i][j] = dis.readLong();
this.table[i][j] = dis.readLong();
}
}
}

/**
 * Java serialization hook. Delegates to {@code writeTo}, so the serialized form is the
 * same binary layout that {@code writeTo} produces.
 *
 * @param out stream supplied by the serialization runtime
 * @throws IOException if writing to {@code out} fails
 */
private void writeObject(ObjectOutputStream out) throws IOException {
  writeTo(out);
}

return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
/**
 * Java deserialization hook. Delegates to {@code readFrom0}, which fills in the fields of
 * this instance from the stream.
 *
 * @param in stream supplied by the serialization runtime
 * @throws IOException if reading from {@code in} fails
 * @throws ClassNotFoundException declared to match the standard hook signature
 */
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
  readFrom0(in);
}
}
5 changes: 5 additions & 0 deletions sql/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
<version>1.5.6</version>
<type>jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_2.10</artifactId>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use scala.binary.version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is always hard coded as _2.10 to make publishing easier.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxin told me this. I'm not quite sure about the details though :)

<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.CountMinSketch

/**
* :: Experimental ::
Expand Down Expand Up @@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}

/**
 * Builds a Count-min Sketch over the column with the given name, using an explicit
 * sketch shape.
 *
 * @param colName name of the column over which the sketch is built
 * @param depth depth of the sketch
 * @param width width of the sketch
 * @param seed random seed
 * @return a [[CountMinSketch]] over column `colName`
 * @since 2.0.0
 */
def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
  val column = Column(colName)
  countMinSketch(column, depth, width, seed)
}

/**
 * Builds a Count-min Sketch over the column with the given name, sized from the
 * requested accuracy.
 *
 * @param colName name of the column over which the sketch is built
 * @param eps relative error of the sketch
 * @param confidence confidence of the sketch
 * @param seed random seed
 * @return a [[CountMinSketch]] over column `colName`
 * @since 2.0.0
 */
def countMinSketch(
    colName: String,
    eps: Double,
    confidence: Double,
    seed: Int): CountMinSketch = {
  val column = Column(colName)
  countMinSketch(column, eps, confidence, seed)
}

/**
 * Builds a Count-min Sketch over the given column, using an explicit sketch shape.
 *
 * @param col the column over which the sketch is built
 * @param depth depth of the sketch
 * @param width width of the sketch
 * @param seed random seed
 * @return a [[CountMinSketch]] over column `col`
 * @since 2.0.0
 */
def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
  // Start from an empty sketch of the requested shape and fill it from the column.
  val emptySketch = CountMinSketch.create(depth, width, seed)
  countMinSketch(col, emptySketch)
}

/**
 * Builds a Count-min Sketch over the given column, sized from the requested accuracy.
 *
 * @param col the column over which the sketch is built
 * @param eps relative error of the sketch
 * @param confidence confidence of the sketch
 * @param seed random seed
 * @return a [[CountMinSketch]] over column `col`
 * @since 2.0.0
 */
def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
  // Start from an empty sketch with the requested accuracy and fill it from the column.
  val emptySketch = CountMinSketch.create(eps, confidence, seed)
  countMinSketch(col, emptySketch)
}

private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(
colType == StringType || colType.isInstanceOf[IntegralType],
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)

singleCol.rdd.aggregate(zero)(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can improve this by using a UDAF in the future.

(sketch: CountMinSketch, row: Row) => {
sketch.add(row.get(0))
sketch
},

(sketch1: CountMinSketch, sketch2: CountMinSketch) => {
sketch1.mergeInPlace(sketch2)
}
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.types.*;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;

public class JavaDataFrameSuite {
Expand Down Expand Up @@ -321,4 +322,29 @@ public void testTextLoad() {
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}

/**
 * Checks that each {@code countMinSketch} overload (column name or {@code Column},
 * depth/width or eps/confidence) returns a sketch with the requested parameters and a
 * total count equal to the number of rows.
 */
@Test
public void testCountMinSketch() {
  DataFrame df = context.range(1000);

  // JUnit's assertEquals takes (expected, actual); keeping that order makes failure
  // messages report the right value as "expected".
  CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
  Assert.assertEquals(1000, sketch1.totalCount());
  Assert.assertEquals(10, sketch1.depth());
  Assert.assertEquals(20, sketch1.width());

  CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
  Assert.assertEquals(1000, sketch2.totalCount());
  Assert.assertEquals(10, sketch2.depth());
  Assert.assertEquals(20, sketch2.width());

  CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
  Assert.assertEquals(1000, sketch3.totalCount());
  Assert.assertEquals(0.001, sketch3.relativeError(), 1e-4);
  Assert.assertEquals(0.99, sketch3.confidence(), 5e-3);

  CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
  Assert.assertEquals(1000, sketch4.totalCount());
  Assert.assertEquals(0.001, sketch4.relativeError(), 1e-4);
  Assert.assertEquals(0.99, sketch4.confidence(), 5e-3);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package org.apache.spark.sql

import java.util.Random

import org.scalatest.Matchers._

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.DoubleType

class DataFrameStatSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
sampled.groupBy("key").count().orderBy("key"),
Seq(Row(0, 6), Row(1, 11)))
}

// This test case only verifies that the `DataFrameStatFunctions.countMinSketch()` overloads
// return `CountMinSketch`es that meet the requested specs. Test cases for `CountMinSketch`
// itself can be found in `CountMinSketchSuite` in project spark-sketch.
test("countMinSketch") {
  val df = sqlContext.range(1000)

  // Overloads taking an explicit shape: column name, then Column.
  val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
  assert(sketch1.totalCount() === 1000)
  assert(sketch1.depth() === 10)
  assert(sketch1.width() === 20)

  val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42)
  assert(sketch2.totalCount() === 1000)
  assert(sketch2.depth() === 10)
  assert(sketch2.width() === 20)

  // Overloads taking the requested accuracy: column name, then Column.
  // Floating-point results are compared with a small tolerance.
  val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
  assert(sketch3.totalCount() === 1000)
  assert(sketch3.relativeError() === 0.001 +- 1e-4)
  assert(sketch3.confidence() === 0.99 +- 5e-3)

  val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42)
  assert(sketch4.totalCount() === 1000)
  // The tolerance is 1e-4; a tolerance of 1e04 (10000) would accept any value.
  assert(sketch4.relativeError() === 0.001 +- 1e-4)
  assert(sketch4.confidence() === 0.99 +- 5e-3)

  // Only string and integral columns are supported; a double column must be rejected.
  intercept[IllegalArgumentException] {
    df.select('id cast DoubleType as 'id)
      .stat
      .countMinSketch('id, depth = 10, width = 20, seed = 42)
  }
}
}