-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-12935][SQL] DataFrame API for Count-Min Sketch #10911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
55d90d5
781043c
65d7e8f
fb23a24
4a40802
e64a2d7
6e29026
3ff902a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,13 +21,16 @@ | |
| import java.io.DataOutputStream; | ||
| import java.io.IOException; | ||
| import java.io.InputStream; | ||
| import java.io.ObjectInputStream; | ||
| import java.io.ObjectOutputStream; | ||
| import java.io.OutputStream; | ||
| import java.io.Serializable; | ||
| import java.io.UnsupportedEncodingException; | ||
| import java.util.Arrays; | ||
| import java.util.Random; | ||
|
|
||
| class CountMinSketchImpl extends CountMinSketch { | ||
| public static final long PRIME_MODULUS = (1L << 31) - 1; | ||
| class CountMinSketchImpl extends CountMinSketch implements Serializable { | ||
| private static final long PRIME_MODULUS = (1L << 31) - 1; | ||
|
|
||
| private int depth; | ||
| private int width; | ||
|
|
@@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch { | |
| private double eps; | ||
| private double confidence; | ||
|
|
||
| private CountMinSketchImpl() { | ||
| } | ||
|
|
||
| CountMinSketchImpl(int depth, int width, int seed) { | ||
| this.depth = depth; | ||
| this.width = width; | ||
|
|
@@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch { | |
| initTablesWith(depth, width, seed); | ||
| } | ||
|
|
||
| CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) { | ||
| this.depth = depth; | ||
| this.width = width; | ||
| this.eps = 2.0 / width; | ||
| this.confidence = 1 - 1 / Math.pow(2, depth); | ||
| this.hashA = hashA; | ||
| this.table = table; | ||
| this.totalCount = totalCount; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object other) { | ||
| if (other == this) { | ||
|
|
@@ -325,27 +321,43 @@ public void writeTo(OutputStream out) throws IOException { | |
| } | ||
|
|
||
| public static CountMinSketchImpl readFrom(InputStream in) throws IOException { | ||
| CountMinSketchImpl sketch = new CountMinSketchImpl(); | ||
| sketch.readFrom0(in); | ||
| return sketch; | ||
| } | ||
|
|
||
| private void readFrom0(InputStream in) throws IOException { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this name is quite weird...
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is actually a common naming style in java - to have the private version named xxx0 |
||
| DataInputStream dis = new DataInputStream(in); | ||
|
|
||
| // Ignores version number | ||
| dis.readInt(); | ||
| int version = dis.readInt(); | ||
| if (version != Version.V1.getVersionNumber()) { | ||
| throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")"); | ||
| } | ||
|
|
||
| long totalCount = dis.readLong(); | ||
| int depth = dis.readInt(); | ||
| int width = dis.readInt(); | ||
| this.totalCount = dis.readLong(); | ||
| this.depth = dis.readInt(); | ||
| this.width = dis.readInt(); | ||
| this.eps = 2.0 / width; | ||
| this.confidence = 1 - 1 / Math.pow(2, depth); | ||
|
|
||
| long hashA[] = new long[depth]; | ||
| this.hashA = new long[depth]; | ||
| for (int i = 0; i < depth; ++i) { | ||
| hashA[i] = dis.readLong(); | ||
| this.hashA[i] = dis.readLong(); | ||
| } | ||
|
|
||
| long table[][] = new long[depth][width]; | ||
| this.table = new long[depth][width]; | ||
| for (int i = 0; i < depth; ++i) { | ||
| for (int j = 0; j < width; ++j) { | ||
| table[i][j] = dis.readLong(); | ||
| this.table[i][j] = dis.readLong(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private void writeObject(ObjectOutputStream out) throws IOException { | ||
| this.writeTo(out); | ||
| } | ||
|
|
||
| return new CountMinSketchImpl(depth, width, totalCount, hashA, table); | ||
| private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { | ||
| this.readFrom0(in); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,6 +42,11 @@ | |
| <version>1.5.6</version> | ||
| <type>jar</type> | ||
| </dependency> | ||
| <dependency> | ||
| <groupId>org.apache.spark</groupId> | ||
| <artifactId>spark-sketch_2.10</artifactId> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually this is always hard coded as
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rxin told me this. I'm not quite sure about the details though :) |
||
| <version>${project.version}</version> | ||
| </dependency> | ||
| <dependency> | ||
| <groupId>org.apache.spark</groupId> | ||
| <artifactId>spark-core_${scala.binary.version}</artifactId> | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ | |
|
|
||
| import org.apache.spark.annotation.Experimental | ||
| import org.apache.spark.sql.execution.stat._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.util.sketch.CountMinSketch | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
|
|
@@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { | |
| def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { | ||
| sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Count-min Sketch over a specified column. | ||
| * | ||
| * @param colName name of the column over which the sketch is built | ||
| * @param depth depth of the sketch | ||
| * @param width width of the sketch | ||
| * @param seed random seed | ||
| * @return a [[CountMinSketch]] over column `colName` | ||
| * @since 2.0.0 | ||
| */ | ||
| def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { | ||
| countMinSketch(Column(colName), depth, width, seed) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Count-min Sketch over a specified column. | ||
| * | ||
| * @param colName name of the column over which the sketch is built | ||
| * @param eps relative error of the sketch | ||
| * @param confidence confidence of the sketch | ||
| * @param seed random seed | ||
| * @return a [[CountMinSketch]] over column `colName` | ||
| * @since 2.0.0 | ||
| */ | ||
| def countMinSketch( | ||
| colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { | ||
| countMinSketch(Column(colName), eps, confidence, seed) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Count-min Sketch over a specified column. | ||
| * | ||
| * @param col the column over which the sketch is built | ||
| * @param depth depth of the sketch | ||
| * @param width width of the sketch | ||
| * @param seed random seed | ||
| * @return a [[CountMinSketch]] over column `colName` | ||
| * @since 2.0.0 | ||
| */ | ||
| def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { | ||
| countMinSketch(col, CountMinSketch.create(depth, width, seed)) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Count-min Sketch over a specified column. | ||
| * | ||
| * @param col the column over which the sketch is built | ||
| * @param eps relative error of the sketch | ||
| * @param confidence confidence of the sketch | ||
| * @param seed random seed | ||
| * @return a [[CountMinSketch]] over column `colName` | ||
| * @since 2.0.0 | ||
| */ | ||
| def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { | ||
| countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) | ||
| } | ||
|
|
||
| private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { | ||
| val singleCol = df.select(col) | ||
| val colType = singleCol.schema.head.dataType | ||
|
|
||
| require( | ||
| colType == StringType || colType.isInstanceOf[IntegralType], | ||
| s"Count-min Sketch only supports string type and integral types, " + | ||
| s"and does not support type $colType." | ||
| ) | ||
|
|
||
| singleCol.rdd.aggregate(zero)( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can improve it by UDAF in the future. |
||
| (sketch: CountMinSketch, row: Row) => { | ||
| sketch.add(row.get(0)) | ||
| sketch | ||
| }, | ||
|
|
||
| (sketch1: CountMinSketch, sketch2: CountMinSketch) => { | ||
| sketch1.mergeInPlace(sketch2) | ||
| } | ||
| ) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about we move these codes into
CountMinSketch.readFrom?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, I prefer not, because then we'll have to make the no-arg constructor package private instead of private.