-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-12938][SQL] DataFrame API for Bloom filter #10937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,10 +19,10 @@ | |
|
|
||
| import java.io.*; | ||
|
|
||
| public class BloomFilterImpl extends BloomFilter { | ||
| public class BloomFilterImpl extends BloomFilter implements Serializable { | ||
|
|
||
| private final int numHashFunctions; | ||
| private final BitArray bits; | ||
| private int numHashFunctions; | ||
| private BitArray bits; | ||
|
|
||
| BloomFilterImpl(int numHashFunctions, long numBits) { | ||
| this(new BitArray(numBits), numHashFunctions); | ||
|
|
@@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) { | |
| this.numHashFunctions = numHashFunctions; | ||
| } | ||
|
|
||
| private BloomFilterImpl() {} | ||
|
|
||
| @Override | ||
| public boolean equals(Object other) { | ||
| if (other == this) { | ||
|
|
@@ -63,55 +65,90 @@ public long bitSize() { | |
| return bits.bitSize(); | ||
| } | ||
|
|
||
| private static long hashObjectToLong(Object item) { | ||
| if (item instanceof String) { | ||
| try { | ||
| byte[] bytes = ((String) item).getBytes("utf-8"); | ||
| return hashBytesToLong(bytes); | ||
| } catch (UnsupportedEncodingException e) { | ||
| throw new RuntimeException("Only support utf-8 string", e); | ||
| } | ||
| private byte[] getBytesFromUTF8String(Object s) { | ||
| try { | ||
| return ((String) s).getBytes("utf-8"); | ||
| } catch (UnsupportedEncodingException e) { | ||
| throw new RuntimeException("Only support utf-8 string", e); | ||
| } | ||
| } | ||
|
|
||
| private long integralToLong(Object i) { | ||
| long longValue; | ||
|
|
||
| if (i instanceof Long) { | ||
| longValue = (Long) i; | ||
| } else if (i instanceof Integer) { | ||
| longValue = ((Integer) i).longValue(); | ||
| } else if (i instanceof Short) { | ||
| longValue = ((Short) i).longValue(); | ||
| } else if (i instanceof Byte) { | ||
| longValue = ((Byte) i).longValue(); | ||
| } else { | ||
| long longValue; | ||
|
|
||
| if (item instanceof Long) { | ||
| longValue = (Long) item; | ||
| } else if (item instanceof Integer) { | ||
| longValue = ((Integer) item).longValue(); | ||
| } else if (item instanceof Short) { | ||
| longValue = ((Short) item).longValue(); | ||
| } else if (item instanceof Byte) { | ||
| longValue = ((Byte) item).longValue(); | ||
| } else { | ||
| throw new IllegalArgumentException( | ||
| "Support for " + item.getClass().getName() + " not implemented" | ||
| ); | ||
| } | ||
| throw new IllegalArgumentException( | ||
| "Support for " + i.getClass().getName() + " not implemented" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "Unsupported data type " + i.getClass().getName() |
||
| ); | ||
| } | ||
|
|
||
| int h1 = Murmur3_x86_32.hashLong(longValue, 0); | ||
| int h2 = Murmur3_x86_32.hashLong(longValue, h1); | ||
| return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); | ||
| return longValue; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean put(Object item) { | ||
| if (item instanceof String) { | ||
| return putBinary(getBytesFromUTF8String(item)); | ||
| } else { | ||
| return putLong(integralToLong(item)); | ||
| } | ||
| } | ||
|
|
||
| private static long hashBytesToLong(byte[] bytes) { | ||
| @Override | ||
| public boolean putBinary(byte[] bytes) { | ||
| int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); | ||
| int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); | ||
| return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); | ||
|
|
||
| long bitSize = bits.bitSize(); | ||
| boolean bitsChanged = false; | ||
| for (int i = 1; i <= numHashFunctions; i++) { | ||
| int combinedHash = h1 + (i * h2); | ||
| // Flip all the bits if it's negative (guaranteed positive number) | ||
| if (combinedHash < 0) { | ||
| combinedHash = ~combinedHash; | ||
| } | ||
| bitsChanged |= bits.set(combinedHash % bitSize); | ||
| } | ||
| return bitsChanged; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean put(Object item) { | ||
| private boolean mightContainBinary(byte[] bytes) { | ||
| int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); | ||
| int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); | ||
|
|
||
| long bitSize = bits.bitSize(); | ||
| for (int i = 1; i <= numHashFunctions; i++) { | ||
| int combinedHash = h1 + (i * h2); | ||
| // Flip all the bits if it's negative (guaranteed positive number) | ||
| if (combinedHash < 0) { | ||
| combinedHash = ~combinedHash; | ||
| } | ||
| if (!bits.get(combinedHash % bitSize)) { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash | ||
| // values by `h1 + i * h2` with 1 <= i <= numHashFunctions. | ||
| // Note that `CountMinSketch` use a different strategy for long type, it hash the input long | ||
| // element with every i to produce n hash values. | ||
| long hash64 = hashObjectToLong(item); | ||
| int h1 = (int) (hash64 >> 32); | ||
| int h2 = (int) hash64; | ||
| @Override | ||
| public boolean putLong(long l) { | ||
| // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n | ||
| // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. | ||
| // Note that `CountMinSketch` use a different strategy, it hash the input long element with | ||
| // every i to produce n hash values. | ||
| // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? | ||
| int h1 = Murmur3_x86_32.hashLong(l, 0); | ||
| int h2 = Murmur3_x86_32.hashLong(l, h1); | ||
|
|
||
| long bitSize = bits.bitSize(); | ||
| boolean bitsChanged = false; | ||
| for (int i = 1; i <= numHashFunctions; i++) { | ||
| int combinedHash = h1 + (i * h2); | ||
|
|
@@ -124,13 +161,11 @@ public boolean put(Object item) { | |
| return bitsChanged; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean mightContain(Object item) { | ||
| long bitSize = bits.bitSize(); | ||
| long hash64 = hashObjectToLong(item); | ||
| int h1 = (int) (hash64 >> 32); | ||
| int h2 = (int) hash64; | ||
| private boolean mightContainLong(long l) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should expose this one as a public API too.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should also have a mightContain variant for strings.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (Can you also fix the count-min sketch so that the two APIs are consistent?) |
||
| int h1 = Murmur3_x86_32.hashLong(l, 0); | ||
| int h2 = Murmur3_x86_32.hashLong(l, h1); | ||
|
|
||
| long bitSize = bits.bitSize(); | ||
| for (int i = 1; i <= numHashFunctions; i++) { | ||
| int combinedHash = h1 + (i * h2); | ||
| // Flip all the bits if it's negative (guaranteed positive number) | ||
|
|
@@ -144,6 +179,15 @@ public boolean mightContain(Object item) { | |
| return true; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean mightContain(Object item) { | ||
| if (item instanceof String) { | ||
| return mightContainBinary(getBytesFromUTF8String(item)); | ||
| } else { | ||
| return mightContainLong(integralToLong(item)); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public boolean isCompatible(BloomFilter other) { | ||
| if (other == null) { | ||
|
|
@@ -191,18 +235,33 @@ public void writeTo(OutputStream out) throws IOException { | |
| DataOutputStream dos = new DataOutputStream(out); | ||
|
|
||
| dos.writeInt(Version.V1.getVersionNumber()); | ||
| bits.writeTo(dos); | ||
| dos.writeInt(numHashFunctions); | ||
| bits.writeTo(dos); | ||
| } | ||
|
|
||
| public static BloomFilterImpl readFrom(InputStream in) throws IOException { | ||
| private void readFrom0(InputStream in) throws IOException { | ||
| DataInputStream dis = new DataInputStream(in); | ||
|
|
||
| int version = dis.readInt(); | ||
| if (version != Version.V1.getVersionNumber()) { | ||
| throw new IOException("Unexpected Bloom filter version number (" + version + ")"); | ||
| } | ||
|
|
||
| return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); | ||
| this.numHashFunctions = dis.readInt(); | ||
| this.bits = BitArray.readFrom(dis); | ||
| } | ||
|
|
||
| public static BloomFilterImpl readFrom(InputStream in) throws IOException { | ||
| BloomFilterImpl filter = new BloomFilterImpl(); | ||
| filter.readFrom0(in); | ||
| return filter; | ||
| } | ||
|
|
||
| private void writeObject(ObjectOutputStream out) throws IOException { | ||
| writeTo(out); | ||
| } | ||
|
|
||
| private void readObject(ObjectInputStream in) throws IOException { | ||
| readFrom0(in); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,10 @@ import java.{lang => jl, util => ju} | |
| import scala.collection.JavaConverters._ | ||
|
|
||
| import org.apache.spark.annotation.Experimental | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.execution.stat._ | ||
| import org.apache.spark.sql.types.{IntegralType, StringType} | ||
| import org.apache.spark.util.sketch.BloomFilter | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
|
|
@@ -309,4 +312,77 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { | |
| def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { | ||
| sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Bloom filter over a specified column. | ||
| * | ||
| * @param colName name of the column over which the filter is built | ||
| * @param expectedNumItems expected number of items which will be put into the filter. | ||
| * @param fpp expected false positive probability of the filter. | ||
| * | ||
| * @since 2.0.0 | ||
| */ | ||
| def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { | ||
| buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp)) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Bloom filter over a specified column. | ||
| * | ||
| * @param col the column over which the filter is built | ||
| * @param expectedNumItems expected number of items which will be put into the filter. | ||
| * @param fpp expected false positive probability of the filter. | ||
| * | ||
| * @since 2.0.0 | ||
| */ | ||
| def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { | ||
| buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp)) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Bloom filter over a specified column. | ||
| * | ||
| * @param colName name of the column over which the filter is built | ||
| * @param expectedNumItems expected number of items which will be put into the filter. | ||
| * @param numBits expected number of bits of the filter. | ||
| * | ||
| * @since 2.0.0 | ||
| */ | ||
| def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { | ||
| buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits)) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Bloom filter over a specified column. | ||
| * | ||
| * @param col the column over which the filter is built | ||
| * @param expectedNumItems expected number of items which will be put into the filter. | ||
| * @param numBits expected number of bits of the filter. | ||
| * | ||
| * @since 2.0.0 | ||
| */ | ||
| def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { | ||
| buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits)) | ||
| } | ||
|
|
||
| private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = { | ||
| val singleCol = df.select(col) | ||
| val colType = singleCol.schema.head.dataType | ||
|
|
||
| require(colType == StringType || colType.isInstanceOf[IntegralType], | ||
| s"Bloom filter only supports string type and integral types, but got $colType.") | ||
|
|
||
| val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { | ||
| (filter, row) => | ||
| filter.putBinary(row.getUTF8String(0).getBytes) | ||
| filter | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also add a comment here to explain the branching? |
||
| } else { | ||
| (filter, row) => | ||
| // TODO: specialize it. | ||
| filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) | ||
| filter | ||
| } | ||
|
|
||
| singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
specific -> specialized
version -> variant