Skip to content

Commit

Permalink
PARQUET-116: Adding type safety for the filter object to be passed to…
Browse files Browse the repository at this point in the history
… user defined predicate
  • Loading branch information
Yash Datta authored and Yash Datta committed Oct 30, 2014
1 parent d5a2b9e commit f51a431
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ public static <T extends Comparable<T>, C extends Column<T> & SupportsLtGt> GtEq
/**
* Keeps records that pass the provided {@link UserDefinedPredicate}
*/
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T>>
UserDefined<T, U> userDefined(Column<T> column, Class<U> clazz, Serializable o) {
return new UserDefined<T, U>(column, clazz, o);
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable>
UserDefined<T, U, S> userDefined(Column<T> column, Class<U> clazz, S o) {
return new UserDefined<T, U, S>(column, clazz, o);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package parquet.filter2.predicate;

import java.io.Serializable;

import parquet.filter2.predicate.Operators.And;
import parquet.filter2.predicate.Operators.Eq;
import parquet.filter2.predicate.Operators.Gt;
Expand Down Expand Up @@ -47,8 +49,8 @@ public static interface Visitor<R> {
R visit(And and);
R visit(Or or);
R visit(Not not);
<T extends Comparable<T>, U extends UserDefinedPredicate<T>> R visit(UserDefined<T, U> udp);
<T extends Comparable<T>, U extends UserDefinedPredicate<T>> R visit(LogicalNotUserDefined<T, U> udp);
<T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> R visit(UserDefined<T, U, S> udp);
<T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> R visit(LogicalNotUserDefined<T, U, S> udp);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package parquet.filter2.predicate;

import java.io.Serializable;

import parquet.filter2.predicate.FilterPredicate.Visitor;
import parquet.filter2.predicate.Operators.And;
import parquet.filter2.predicate.Operators.Eq;
Expand Down Expand Up @@ -84,12 +86,12 @@ public FilterPredicate visit(Not not) {
}

@Override
public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> FilterPredicate visit(UserDefined<T, U> udp) {
public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> FilterPredicate visit(UserDefined<T, U, S> udp) {
return udp;
}

@Override
public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> FilterPredicate visit(LogicalNotUserDefined<T, U> udp) {
public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> FilterPredicate visit(LogicalNotUserDefined<T, U, S> udp) {
return udp;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package parquet.filter2.predicate;

import java.io.Serializable;

import parquet.filter2.predicate.FilterPredicate.Visitor;
import parquet.filter2.predicate.Operators.And;
import parquet.filter2.predicate.Operators.Eq;
Expand Down Expand Up @@ -79,12 +81,12 @@ public FilterPredicate visit(Not not) {
}

@Override
public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> FilterPredicate visit(UserDefined<T, U> udp) {
return new LogicalNotUserDefined<T, U>(udp);
public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> FilterPredicate visit(UserDefined<T, U, S> udp) {
return new LogicalNotUserDefined<T, U, S>(udp);
}

@Override
public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> FilterPredicate visit(LogicalNotUserDefined<T, U> udp) {
public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> FilterPredicate visit(LogicalNotUserDefined<T, U, S> udp) {
return udp.getUserDefined();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,15 @@ public int hashCode() {
}
}

public static final class UserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> implements FilterPredicate, Serializable {
public static final class UserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> implements FilterPredicate, Serializable {
private final Column<T> column;
private final Class<U> udpClass;
private final String toString;
private final Serializable o;
private final S o;
private static final String INSTANTIATION_ERROR_MESSAGE =
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";

UserDefined(Column<T> column, Class<U> udpClass, Serializable o) {
UserDefined(Column<T> column, Class<U> udpClass, S o) {
this.column = checkNotNull(column, "column");
this.udpClass = checkNotNull(udpClass, "udpClass");
String name = getClass().getSimpleName().toLowerCase();
Expand All @@ -367,7 +367,7 @@ public Class<U> getUserDefinedPredicateClass() {
return udpClass;
}

public Serializable getFilterObject() {
public S getFilterObject() {
return o;
}

Expand Down Expand Up @@ -415,16 +415,16 @@ public int hashCode() {

// Represents the inverse of a UserDefined. It is equivalent to not(userDefined), without the use
// of the not() operator
public static final class LogicalNotUserDefined <T extends Comparable<T>, U extends UserDefinedPredicate<T>> implements FilterPredicate, Serializable {
private final UserDefined<T, U> udp;
public static final class LogicalNotUserDefined <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> implements FilterPredicate, Serializable {
private final UserDefined<T, U, S> udp;
private final String toString;

LogicalNotUserDefined(UserDefined<T, U> userDefined) {
LogicalNotUserDefined(UserDefined<T, U, S> userDefined) {
this.udp = checkNotNull(userDefined, "userDefined");
this.toString = "inverted(" + udp + ")";
}

public UserDefined<T, U> getUserDefined() {
public UserDefined<T, U, S> getUserDefined() {
return udp;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package parquet.filter2.predicate;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -128,13 +129,13 @@ public Void visit(Not not) {
}

@Override
public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> Void visit(UserDefined<T, U> udp) {
public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> Void visit(UserDefined<T, U, S> udp) {
validateColumn(udp.getColumn());
return null;
}

@Override
public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> Void visit(LogicalNotUserDefined<T, U> udp) {
public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> Void visit(LogicalNotUserDefined<T, U, S> udp) {
return udp.getUserDefined().accept(this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/
// TODO: consider avoiding autoboxing and adding the specialized methods for each type
// TODO: downside is that's fairly unwieldy for users
public abstract class UserDefinedPredicate<T extends Comparable<T>> {
public abstract class UserDefinedPredicate<T extends Comparable<T>, S extends Serializable> {

/**
* A udp must have a default constructor.
Expand All @@ -26,7 +26,7 @@ public UserDefinedPredicate() { }
* Return true to keep the record with this value, false to drop it.
* o is a filter object that can be used for filtering the value.
*/
public abstract boolean keep(T value, Serializable o);
public abstract boolean keep(T value, S o);

/**
* Given information about a group of records (eg, the min and max value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.io.Serializable;

public class DummyUdp extends UserDefinedPredicate<Integer> {
public class DummyUdp extends UserDefinedPredicate<Integer, Serializable> {

@Override
public boolean keep(Integer value, Serializable o) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package parquet.filter2.predicate;

import java.io.Serializable;
import org.junit.Test;

import parquet.filter2.predicate.Operators.DoubleColumn;
Expand Down Expand Up @@ -40,7 +41,7 @@ public class TestLogicalInverseRewriter {
and(gt(doubleColumn, 12.0),
or(
or(eq(intColumn, 7), notEq(intColumn, 17)),
new LogicalNotUserDefined<Integer, DummyUdp>(userDefined(intColumn, DummyUdp.class, null)))),
new LogicalNotUserDefined<Integer, DummyUdp, Serializable>(userDefined(intColumn, DummyUdp.class, null)))),
or(gt(doubleColumn, 100.0), lt(intColumn, 77)));

private static void assertNoOp(FilterPredicate p) {
Expand All @@ -49,7 +50,7 @@ private static void assertNoOp(FilterPredicate p) {

@Test
public void testBaseCases() {
UserDefined<Integer, DummyUdp> ud = userDefined(intColumn, DummyUdp.class, null);
UserDefined<Integer, DummyUdp, Serializable> ud = userDefined(intColumn, DummyUdp.class, null);

assertNoOp(eq(intColumn, 17));
assertNoOp(notEq(intColumn, 17));
Expand All @@ -67,7 +68,7 @@ public void testBaseCases() {
assertEquals(gt(intColumn, 17), rewrite(not(ltEq(intColumn, 17))));
assertEquals(ltEq(intColumn, 17), rewrite(not(gt(intColumn, 17))));
assertEquals(lt(intColumn, 17), rewrite(not(gtEq(intColumn, 17))));
assertEquals(new LogicalNotUserDefined<Integer, DummyUdp>(ud), rewrite(not(ud)));
assertEquals(new LogicalNotUserDefined<Integer, DummyUdp, Serializable>(ud), rewrite(not(ud)));

FilterPredicate notedAnd = not(and(eq(intColumn, 17), eq(doubleColumn, 12.0)));
FilterPredicate distributedAnd = or(notEq(intColumn, 17), notEq(doubleColumn, 12.0));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package parquet.filter2.predicate;

import java.io.Serializable;
import org.junit.Test;

import parquet.filter2.predicate.Operators.DoubleColumn;
Expand All @@ -26,7 +27,7 @@ public class TestLogicalInverter {
private static final IntColumn intColumn = intColumn("a.b.c");
private static final DoubleColumn doubleColumn = doubleColumn("a.b.c");

private static final UserDefined<Integer, DummyUdp> ud = userDefined(intColumn, DummyUdp.class, null);
private static final UserDefined<Integer, DummyUdp, Serializable> ud = userDefined(intColumn, DummyUdp.class, null);

private static final FilterPredicate complex =
and(
Expand All @@ -41,7 +42,7 @@ public class TestLogicalInverter {
and(gt(doubleColumn, 12.0),
or(
or(eq(intColumn, 7), notEq(intColumn, 17)),
new LogicalNotUserDefined<Integer, DummyUdp>(userDefined(intColumn, DummyUdp.class, null)))),
new LogicalNotUserDefined<Integer, DummyUdp, Serializable>(userDefined(intColumn, DummyUdp.class, null)))),
and(ltEq(doubleColumn, 100.0), eq(intColumn, 77)));

@Test
Expand All @@ -63,10 +64,10 @@ public void testBaseCases() {

assertEquals(eq(intColumn, 17), invert(not(eq(intColumn, 17))));

UserDefined<Integer, DummyUdp> ud = userDefined(intColumn, DummyUdp.class, null);
assertEquals(new LogicalNotUserDefined<Integer, DummyUdp>(ud), invert(ud));
UserDefined<Integer, DummyUdp, Serializable> ud = userDefined(intColumn, DummyUdp.class, null);
assertEquals(new LogicalNotUserDefined<Integer, DummyUdp, Serializable>(ud), invert(ud));
assertEquals(ud, invert(not(ud)));
assertEquals(ud, invert(new LogicalNotUserDefined<Integer, DummyUdp>(ud)));
assertEquals(ud, invert(new LogicalNotUserDefined<Integer, DummyUdp, Serializable>(ud)));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package parquet.filter2.predicate;

import java.io.Serializable;
import org.junit.Test;

import java.io.Serializable;
Expand Down Expand Up @@ -50,7 +51,7 @@ public class TestSchemaCompatibilityValidator {
userDefined(intBar, DummyUdp.class, null))),
or(gt(stringC, Binary.fromString("bar")), notEq(stringC, Binary.fromString("baz"))));

static class LongDummyUdp extends UserDefinedPredicate<Long> {
static class LongDummyUdp extends UserDefinedPredicate<Long, Serializable> {
@Override
public boolean keep(Long value, Serializable o) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ public void run() throws IOException {
addVisitEnd();

add(" @Override\n" +
" public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> IncrementallyUpdatedFilterPredicate visit(UserDefined<T, U> pred) {\n");
" public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> IncrementallyUpdatedFilterPredicate visit(UserDefined<T, U, S> pred) {\n");
addUdpBegin();
for (TypeInfo info : TYPES) {
addUdpCase(info, false);
}
addVisitEnd();

add(" @Override\n" +
" public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> IncrementallyUpdatedFilterPredicate visit(LogicalNotUserDefined<T, U> notPred) {\n" +
" UserDefined<T, U> pred = notPred.getUserDefined();\n");
" public <T extends Comparable<T>, U extends UserDefinedPredicate<T, S>, S extends Serializable> IncrementallyUpdatedFilterPredicate visit(LogicalNotUserDefined<T, U, S> notPred) {\n" +
" UserDefined<T, U, S> pred = notPred.getUserDefined();\n");
addUdpBegin();
for (TypeInfo info : TYPES) {
addUdpCase(info, true);
Expand Down Expand Up @@ -223,7 +223,7 @@ private void addUdpBegin() throws IOException {
"\n" +
" final U udp = pred.getUserDefinedPredicate();\n" +
"\n" +
" final Serializable o = pred.getFilterObject();\n" +
" final S o = pred.getFilterObject();\n" +
"\n");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public boolean keep(User u) {
});
}

public static class StartWithP extends UserDefinedPredicate<Binary> {
public static class StartWithP extends UserDefinedPredicate<Binary, Serializable> {

@Override
public boolean keep(Binary value, Serializable o) {
Expand All @@ -166,16 +166,15 @@ public boolean inverseCanDrop(Statistics<Binary> statistics) {
}
}

public static class SetInFilter extends UserDefinedPredicate<Long> {
public static class SetInFilter extends UserDefinedPredicate<Long, HashSet<Long>> {

@Override
public boolean keep(Long value, Serializable o) {
public boolean keep(Long value, HashSet o) {
if (value == null) {
return false;
}

Set<Long> hSet = (HashSet<Long>) o;
return hSet.contains(value);
return o.contains(value);
}

@Override
Expand Down
25 changes: 13 additions & 12 deletions parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package parquet.filter2.dsl

import java.io.Serializable
import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong}

import parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators, UserDefinedPredicate}
Expand All @@ -26,10 +27,10 @@ import parquet.io.api.Binary
*/
object Dsl {

private[Dsl] trait Column[T <: Comparable[T], C <: Operators.Column[T]] {
private[Dsl] trait Column[T <: Comparable[T], C <: Operators.Column[T], S <: java.io.Serializable] {
val javaColumn: C

def filterBy[U <: UserDefinedPredicate[T]](clazz: Class[U]) = FilterApi.userDefined(javaColumn, clazz, null)
def filterBy[U <: UserDefinedPredicate[T, S]](clazz: Class[U], o: S) = FilterApi.userDefined(javaColumn, clazz, o)

// this is not supported because it allows for easy mistakes. For example:
// val pred = IntColumn("foo") == "hello"
Expand All @@ -38,40 +39,40 @@ object Dsl {
throw new UnsupportedOperationException("You probably meant to use === or !==")
}

case class IntColumn(columnPath: String) extends Column[JInt, Operators.IntColumn] {
case class IntColumn(columnPath: String) extends Column[JInt, Operators.IntColumn, Serializable] {
override val javaColumn = FilterApi.intColumn(columnPath)
}

case class LongColumn(columnPath: String) extends Column[JLong, Operators.LongColumn] {
case class LongColumn(columnPath: String) extends Column[JLong, Operators.LongColumn, Serializable] {
override val javaColumn = FilterApi.longColumn(columnPath)
}

case class FloatColumn(columnPath: String) extends Column[JFloat, Operators.FloatColumn] {
case class FloatColumn(columnPath: String) extends Column[JFloat, Operators.FloatColumn, Serializable] {
override val javaColumn = FilterApi.floatColumn(columnPath)
}

case class DoubleColumn(columnPath: String) extends Column[JDouble, Operators.DoubleColumn] {
case class DoubleColumn(columnPath: String) extends Column[JDouble, Operators.DoubleColumn, Serializable] {
override val javaColumn = FilterApi.doubleColumn(columnPath)
}

case class BooleanColumn(columnPath: String) extends Column[JBoolean, Operators.BooleanColumn] {
case class BooleanColumn(columnPath: String) extends Column[JBoolean, Operators.BooleanColumn, Serializable] {
override val javaColumn = FilterApi.booleanColumn(columnPath)
}

case class BinaryColumn(columnPath: String) extends Column[Binary, Operators.BinaryColumn] {
case class BinaryColumn(columnPath: String) extends Column[Binary, Operators.BinaryColumn, Serializable] {
override val javaColumn = FilterApi.binaryColumn(columnPath)
}

implicit def enrichEqNotEq[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsEqNotEq](column: Column[T, C]): SupportsEqNotEq[T,C] = new SupportsEqNotEq(column)
implicit def enrichEqNotEq[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsEqNotEq, S <: Serializable](column: Column[T, C, S]): SupportsEqNotEq[T,C, S] = new SupportsEqNotEq(column)

class SupportsEqNotEq[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsEqNotEq](val column: Column[T, C]) {
class SupportsEqNotEq[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsEqNotEq, S <: Serializable](val column: Column[T, C, S]) {
def ===(v: T) = FilterApi.eq(column.javaColumn, v)
def !== (v: T) = FilterApi.notEq(column.javaColumn, v)
}

implicit def enrichLtGt[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsLtGt](column: Column[T, C]): SupportsLtGt[T,C] = new SupportsLtGt(column)
implicit def enrichLtGt[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsLtGt, S <: Serializable](column: Column[T, C, S]): SupportsLtGt[T,C, S] = new SupportsLtGt(column)

class SupportsLtGt[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsLtGt](val column: Column[T, C]) {
class SupportsLtGt[T <: Comparable[T], C <: Operators.Column[T] with Operators.SupportsLtGt, S <: Serializable](val column: Column[T, C, S]) {
def >(v: T) = FilterApi.gt(column.javaColumn, v)
def >=(v: T) = FilterApi.gtEq(column.javaColumn, v)
def <(v: T) = FilterApi.lt(column.javaColumn, v)
Expand Down
Loading

0 comments on commit f51a431

Please sign in to comment.