Skip to content

Commit

Permalink
PARQUET-116: Pass a filter object to user defined predicate in filter…
Browse files Browse the repository at this point in the history
…2 api

Currently for creating a user defined predicate using the new filter api, no value can be passed to create a dynamic filter at runtime. This reduces the usefulness of the user defined predicate, and meaningful predicates cannot be created. We can add a generic Object value that is passed through the api, which can internally be used in the keep function of the user defined predicate for creating many different types of filters.
For example, in spark sql, we can pass in a list of filter values for a where IN clause query and filter the row values based on that list.

Author: Yash Datta <[email protected]>
Author: Alex Levenson <[email protected]>
Author: Yash Datta <[email protected]>

Closes apache#73 from saucam/master and squashes the following commits:

7231a3b [Yash Datta] Merge pull request #3 from isnotinvain/alexlevenson/fix-binary-compat
dcc276b [Alex Levenson] Ignore binary incompatibility in private filter2 class
7bfa5ad [Yash Datta] Merge pull request #2 from isnotinvain/alexlevenson/simplify-udp-state
0187376 [Alex Levenson] Resolve merge conflicts
25aa716 [Alex Levenson] Simplify user defined predicates with state
51952f8 [Yash Datta] PARQUET-116: Fix whitespace
d7b7159 [Yash Datta] PARQUET-116: Make UserDefined abstract, add two subclasses, one accepting udp class, other accepting serializable udp instance
40d394a [Yash Datta] PARQUET-116: Fix whitespace
9a63611 [Yash Datta] PARQUET-116: Fix whitespace
7caa4dc [Yash Datta] PARQUET-116: Add ConfiguredUserDefined that takes a serialiazble udp directly
0eaabf4 [Yash Datta] PARQUET-116: Move the config object from keep method to a configure method in udp predicate
f51a431 [Yash Datta] PARQUET-116: Adding type safety for the filter object to be passed to user defined predicate
d5a2b9e [Yash Datta] PARQUET-116: Enforce that the filter object to be passed must be Serializable
dfd0478 [Yash Datta] PARQUET-116: Add a test case for passing a filter object to user defined predicate
4ab46ec [Yash Datta] PARQUET-116: Pass a filter object to user defined predicate in filter2 api
  • Loading branch information
Yash Datta authored and rdblue committed Mar 7, 2015
1 parent 7970b87 commit 82f993e
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package parquet.filter2.predicate;

import java.io.Serializable;

import parquet.common.schema.ColumnPath;
import parquet.filter2.predicate.Operators.And;
import parquet.filter2.predicate.Operators.BinaryColumn;
Expand All @@ -38,6 +40,8 @@
import parquet.filter2.predicate.Operators.SupportsEqNotEq;
import parquet.filter2.predicate.Operators.SupportsLtGt;
import parquet.filter2.predicate.Operators.UserDefined;
import parquet.filter2.predicate.Operators.UserDefinedByClass;
import parquet.filter2.predicate.Operators.UserDefinedByInstance;

/**
* The Filter API is expressed through these static methods.
Expand Down Expand Up @@ -162,10 +166,23 @@ public static <T extends Comparable<T>, C extends Column<T> & SupportsLtGt> GtEq

/**
* Keeps records that pass the provided {@link UserDefinedPredicate}
*
* The provided class must have a default constructor. To use an instance
* of a UserDefinedPredicate instead, see {@link #userDefined(column, udp)} below.
*/
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T>>
UserDefined<T, U> userDefined(Column<T> column, Class<U> clazz) {
return new UserDefined<T, U>(column, clazz);
return new UserDefinedByClass<T, U>(column, clazz);
}

/**
* Keeps records that pass the provided {@link UserDefinedPredicate}
*
* The provided instance of UserDefinedPredicate must be serializable.
*/
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable>
UserDefined<T, U> userDefined(Column<T> column, U udp) {
return new UserDefinedByInstance<T, U>(column, udp);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,33 @@ public int hashCode() {
}
}

public static final class UserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> implements FilterPredicate, Serializable {
private final Column<T> column;
public static abstract class UserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> implements FilterPredicate, Serializable {
protected final Column<T> column;

UserDefined(Column<T> column) {
this.column = checkNotNull(column, "column");
}

public Column<T> getColumn() {
return column;
}

public abstract U getUserDefinedPredicate();

@Override
public <R> R accept(Visitor<R> visitor) {
return visitor.visit(this);
}
}

public static final class UserDefinedByClass<T extends Comparable<T>, U extends UserDefinedPredicate<T>> extends UserDefined<T, U> {
private final Class<U> udpClass;
private final String toString;
private static final String INSTANTIATION_ERROR_MESSAGE =
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";

UserDefined(Column<T> column, Class<U> udpClass) {
this.column = checkNotNull(column, "column");
UserDefinedByClass(Column<T> column, Class<U> udpClass) {
super(column);
this.udpClass = checkNotNull(udpClass, "udpClass");
String name = getClass().getSimpleName().toLowerCase();
this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpClass.getName() + ")";
Expand All @@ -375,14 +393,11 @@ public static final class UserDefined<T extends Comparable<T>, U extends UserDef
getUserDefinedPredicate();
}

public Column<T> getColumn() {
return column;
}

public Class<U> getUserDefinedPredicateClass() {
return udpClass;
}

@Override
public U getUserDefinedPredicate() {
try {
return udpClass.newInstance();
Expand All @@ -394,8 +409,46 @@ public U getUserDefinedPredicate() {
}

@Override
public <R> R accept(Visitor<R> visitor) {
return visitor.visit(this);
public String toString() {
return toString;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

UserDefinedByClass that = (UserDefinedByClass) o;

if (!column.equals(that.column)) return false;
if (!udpClass.equals(that.udpClass)) return false;

return true;
}

@Override
public int hashCode() {
int result = column.hashCode();
result = 31 * result + udpClass.hashCode();
result = result * 31 + getClass().hashCode();
return result;
}
}

public static final class UserDefinedByInstance<T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable> extends UserDefined<T, U> {
private final String toString;
private final U udpInstance;

UserDefinedByInstance(Column<T> column, U udpInstance) {
super(column);
this.udpInstance = checkNotNull(udpInstance, "udpInstance");
String name = getClass().getSimpleName().toLowerCase();
this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpInstance + ")";
}

@Override
public U getUserDefinedPredicate() {
return udpInstance;
}

@Override
Expand All @@ -408,18 +461,18 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

UserDefined that = (UserDefined) o;
UserDefinedByInstance that = (UserDefinedByInstance) o;

if (!column.equals(that.column)) return false;
if (!udpClass.equals(that.udpClass)) return false;
if (!udpInstance.equals(that.udpInstance)) return false;

return true;
}

@Override
public int hashCode() {
int result = column.hashCode();
result = 31 * result + udpClass.hashCode();
result = 31 * result + udpInstance.hashCode();
result = result * 31 + getClass().hashCode();
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ public UserDefinedPredicate() { }
* }
*/
public abstract boolean inverseCanDrop(Statistics<T> statistics);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import org.junit.Test;

Expand All @@ -32,9 +33,11 @@
import parquet.filter2.predicate.Operators.Eq;
import parquet.filter2.predicate.Operators.Gt;
import parquet.filter2.predicate.Operators.IntColumn;
import parquet.filter2.predicate.Operators.LongColumn;
import parquet.filter2.predicate.Operators.Not;
import parquet.filter2.predicate.Operators.Or;
import parquet.filter2.predicate.Operators.UserDefined;
import parquet.filter2.predicate.Operators.UserDefinedByClass;
import parquet.io.api.Binary;

import static org.junit.Assert.assertEquals;
Expand All @@ -45,6 +48,7 @@
import static parquet.filter2.predicate.FilterApi.eq;
import static parquet.filter2.predicate.FilterApi.gt;
import static parquet.filter2.predicate.FilterApi.intColumn;
import static parquet.filter2.predicate.FilterApi.longColumn;
import static parquet.filter2.predicate.FilterApi.not;
import static parquet.filter2.predicate.FilterApi.notEq;
import static parquet.filter2.predicate.FilterApi.or;
Expand All @@ -54,6 +58,7 @@
public class TestFilterApiMethods {

private static final IntColumn intColumn = intColumn("a.b.c");
private static final LongColumn longColumn = longColumn("a.b.l");
private static final DoubleColumn doubleColumn = doubleColumn("x.y.z");
private static final BinaryColumn binColumn = binaryColumn("a.string.column");

Expand Down Expand Up @@ -100,15 +105,15 @@ public void testUdp() {
FilterPredicate predicate = or(eq(doubleColumn, 12.0), userDefined(intColumn, DummyUdp.class));
assertTrue(predicate instanceof Or);
FilterPredicate ud = ((Or) predicate).getRight();
assertTrue(ud instanceof UserDefined);
assertEquals(DummyUdp.class, ((UserDefined) ud).getUserDefinedPredicateClass());
assertTrue(ud instanceof UserDefinedByClass);
assertEquals(DummyUdp.class, ((UserDefinedByClass) ud).getUserDefinedPredicateClass());
assertTrue(((UserDefined) ud).getUserDefinedPredicate() instanceof DummyUdp);
}

@Test
public void testSerializable() throws Exception {
public void testSerializable() throws Exception {
BinaryColumn binary = binaryColumn("foo");
FilterPredicate p = or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi")));
FilterPredicate p = and(or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))), userDefined(longColumn, new IsMultipleOf(7)));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(p);
Expand All @@ -118,4 +123,50 @@ public void testSerializable() throws Exception {
FilterPredicate read = (FilterPredicate) is.readObject();
assertEquals(p, read);
}

public static class IsMultipleOf extends UserDefinedPredicate<Long> implements Serializable {

private long of;

public IsMultipleOf(long of) {
this.of = of;
}

@Override
public boolean keep(Long value) {
if (value == null) {
return false;
}
return value % of == 0;
}

@Override
public boolean canDrop(Statistics<Long> statistics) {
return false;
}

@Override
public boolean inverseCanDrop(Statistics<Long> statistics) {
return false;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

IsMultipleOf that = (IsMultipleOf) o;
return this.of == that.of;
}

@Override
public int hashCode() {
return new Long(of).hashCode();
}

@Override
public String toString() {
return "IsMultipleOf(" + of + ")";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.HashSet;

import org.junit.BeforeClass;
import org.junit.Test;
Expand All @@ -33,6 +36,7 @@
import parquet.filter2.predicate.FilterPredicate;
import parquet.filter2.predicate.Operators.BinaryColumn;
import parquet.filter2.predicate.Operators.DoubleColumn;
import parquet.filter2.predicate.Operators.LongColumn;
import parquet.filter2.predicate.Statistics;
import parquet.filter2.predicate.UserDefinedPredicate;
import parquet.filter2.recordlevel.PhoneBookWriter.Location;
Expand All @@ -44,6 +48,7 @@
import static parquet.filter2.predicate.FilterApi.and;
import static parquet.filter2.predicate.FilterApi.binaryColumn;
import static parquet.filter2.predicate.FilterApi.doubleColumn;
import static parquet.filter2.predicate.FilterApi.longColumn;
import static parquet.filter2.predicate.FilterApi.eq;
import static parquet.filter2.predicate.FilterApi.gt;
import static parquet.filter2.predicate.FilterApi.not;
Expand Down Expand Up @@ -178,6 +183,34 @@ public boolean inverseCanDrop(Statistics<Binary> statistics) {
return false;
}
}

public static class SetInFilter extends UserDefinedPredicate<Long> implements Serializable {

private HashSet<Long> hSet;

public SetInFilter(HashSet<Long> phSet) {
hSet = phSet;
}

@Override
public boolean keep(Long value) {
if (value == null) {
return false;
}

return hSet.contains(value);
}

@Override
public boolean canDrop(Statistics<Long> statistics) {
return false;
}

@Override
public boolean inverseCanDrop(Statistics<Long> statistics) {
return false;
}
}

@Test
public void testNameNotStartWithP() throws Exception {
Expand All @@ -194,6 +227,27 @@ public boolean keep(User u) {
}
});
}

@Test
public void testUserDefinedByInstance() throws Exception {
LongColumn name = longColumn("id");

final HashSet<Long> h = new HashSet<Long>();
h.add(20L);
h.add(27L);
h.add(28L);

FilterPredicate pred = userDefined(name, new SetInFilter(h));

List<Group> found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(pred));

assertFilter(found, new UserFilter() {
@Override
public boolean keep(User u) {
return u != null && h.contains(u.getId());
}
});
}

@Test
public void testComplex() throws Exception {
Expand Down
3 changes: 3 additions & 0 deletions parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package parquet.filter2.dsl

import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong}
import java.io.Serializable

import parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators, UserDefinedPredicate}
import parquet.io.api.Binary
Expand Down Expand Up @@ -48,6 +49,8 @@ object Dsl {
val javaColumn: C

def filterBy[U <: UserDefinedPredicate[T]](clazz: Class[U]) = FilterApi.userDefined(javaColumn, clazz)

def filterBy[U <: UserDefinedPredicate[T] with Serializable](udp: U) = FilterApi.userDefined(javaColumn, udp)

// this is not supported because it allows for easy mistakes. For example:
// val pred = IntColumn("foo") == "hello"
Expand Down
Loading

0 comments on commit 82f993e

Please sign in to comment.