Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ public enum Version {
/**
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Number of hash functions (32 bit)
* - Total number of words of the underlying bit array (32 bit)
* - The words/longs (numWords * 64 bit)
* - Number of hash functions (32 bit)
*/
V1(1);

Expand Down Expand Up @@ -95,6 +95,16 @@ int getVersionNumber() {
*/
public abstract boolean put(Object item);

/**
* A specific version of {@link #put(Object)}, that can only be used to put byte array.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

specific -> specialized

version -> variant

*/
public abstract boolean putBinary(byte[] bytes);

/**
* A specific version of {@link #put(Object)}, that can only be used to put long.
*/
public abstract boolean putLong(long l);

/**
* Determines whether a given bloom filter is compatible with this bloom filter. For two
* bloom filters to be compatible, they must have the same bit size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import java.io.*;

public class BloomFilterImpl extends BloomFilter {
public class BloomFilterImpl extends BloomFilter implements Serializable {

private final int numHashFunctions;
private final BitArray bits;
private int numHashFunctions;
private BitArray bits;

BloomFilterImpl(int numHashFunctions, long numBits) {
this(new BitArray(numBits), numHashFunctions);
Expand All @@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) {
this.numHashFunctions = numHashFunctions;
}

private BloomFilterImpl() {}

@Override
public boolean equals(Object other) {
if (other == this) {
Expand Down Expand Up @@ -63,55 +65,90 @@ public long bitSize() {
return bits.bitSize();
}

private static long hashObjectToLong(Object item) {
if (item instanceof String) {
try {
byte[] bytes = ((String) item).getBytes("utf-8");
return hashBytesToLong(bytes);
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("Only support utf-8 string", e);
}
private byte[] getBytesFromUTF8String(Object s) {
try {
return ((String) s).getBytes("utf-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("Only support utf-8 string", e);
}
}

private long integralToLong(Object i) {
long longValue;

if (i instanceof Long) {
longValue = (Long) i;
} else if (i instanceof Integer) {
longValue = ((Integer) i).longValue();
} else if (i instanceof Short) {
longValue = ((Short) i).longValue();
} else if (i instanceof Byte) {
longValue = ((Byte) i).longValue();
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}
throw new IllegalArgumentException(
"Support for " + i.getClass().getName() + " not implemented"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Unsupported data type " + i.getClass().getName()

);
}

int h1 = Murmur3_x86_32.hashLong(longValue, 0);
int h2 = Murmur3_x86_32.hashLong(longValue, h1);
return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
return longValue;
}

@Override
public boolean put(Object item) {
if (item instanceof String) {
return putBinary(getBytesFromUTF8String(item));
} else {
return putLong(integralToLong(item));
}
}

private static long hashBytesToLong(byte[] bytes) {
@Override
public boolean putBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
bitsChanged |= bits.set(combinedHash % bitSize);
}
return bitsChanged;
}

@Override
public boolean put(Object item) {
private boolean mightContainBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
if (!bits.get(combinedHash % bitSize)) {
return false;
}
}
return true;
}

// Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
// values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy for long type, it hash the input long
// element with every i to produce n hash values.
long hash64 = hashObjectToLong(item);
int h1 = (int) (hash64 >> 32);
int h2 = (int) hash64;
@Override
public boolean putLong(long l) {
// Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
// hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy, it hash the input long element with
// every i to produce n hash values.
// TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
int h1 = Murmur3_x86_32.hashLong(l, 0);
int h2 = Murmur3_x86_32.hashLong(l, h1);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
Expand All @@ -124,13 +161,11 @@ public boolean put(Object item) {
return bitsChanged;
}

@Override
public boolean mightContain(Object item) {
long bitSize = bits.bitSize();
long hash64 = hashObjectToLong(item);
int h1 = (int) (hash64 >> 32);
int h2 = (int) hash64;
private boolean mightContainLong(long l) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should open this one as an api too

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and have a mightContain for string as well

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(can you guys also fix count-min sketch to make them consistent?)

int h1 = Murmur3_x86_32.hashLong(l, 0);
int h2 = Murmur3_x86_32.hashLong(l, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
Expand All @@ -144,6 +179,15 @@ public boolean mightContain(Object item) {
return true;
}

@Override
public boolean mightContain(Object item) {
if (item instanceof String) {
return mightContainBinary(getBytesFromUTF8String(item));
} else {
return mightContainLong(integralToLong(item));
}
}

@Override
public boolean isCompatible(BloomFilter other) {
if (other == null) {
Expand Down Expand Up @@ -191,18 +235,33 @@ public void writeTo(OutputStream out) throws IOException {
DataOutputStream dos = new DataOutputStream(out);

dos.writeInt(Version.V1.getVersionNumber());
bits.writeTo(dos);
dos.writeInt(numHashFunctions);
bits.writeTo(dos);
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);

int version = dis.readInt();
if (version != Version.V1.getVersionNumber()) {
throw new IOException("Unexpected Bloom filter version number (" + version + ")");
}

return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
this.numHashFunctions = dis.readInt();
this.bits = BitArray.readFrom(dis);
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
BloomFilterImpl filter = new BloomFilterImpl();
filter.readFrom0(in);
return filter;
}

private void writeObject(ObjectOutputStream out) throws IOException {
writeTo(out);
}

private void readObject(ObjectInputStream in) throws IOException {
readFrom0(in);
}
}
5 changes: 5 additions & 0 deletions sql/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
<version>1.5.6</version>
<type>jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import java.{lang => jl, util => ju}
import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.types.{IntegralType, StringType}
import org.apache.spark.util.sketch.BloomFilter

/**
* :: Experimental ::
Expand Down Expand Up @@ -309,4 +312,77 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName name of the column over which the filter is built
* @param expectedNumItems expected number of items which will be put into the filter.
* @param fpp expected false positive probability of the filter.
*
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp))
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col the column over which the filter is built
* @param expectedNumItems expected number of items which will be put into the filter.
* @param fpp expected false positive probability of the filter.
*
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp))
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName name of the column over which the filter is built
* @param expectedNumItems expected number of items which will be put into the filter.
* @param numBits expected number of bits of the filter.
*
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits))
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col the column over which the filter is built
* @param expectedNumItems expected number of items which will be put into the filter.
* @param numBits expected number of bits of the filter.
*
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits))
}

private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")

val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
(filter, row) =>
filter.putBinary(row.getUTF8String(0).getBytes)
filter

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add comment to explain the branching at here?

} else {
(filter, row) =>
// TODO: specialize it.
filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
filter
}

singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
}
}
Loading