Skip to content

Commit

Permalink
Merge pull request #2 from isnotinvain/alexlevenson/simplify-udp-state
Browse files Browse the repository at this point in the history
Simplify user defined predicates with state, Add more test cases
  • Loading branch information
saucam committed Feb 4, 2015
2 parents 51952f8 + 0187376 commit 7bfa5ad
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import parquet.filter2.predicate.Operators.BinaryColumn;
import parquet.filter2.predicate.Operators.BooleanColumn;
import parquet.filter2.predicate.Operators.Column;
import parquet.filter2.predicate.Operators.ConfiguredUserDefined;
import parquet.filter2.predicate.Operators.DoubleColumn;
import parquet.filter2.predicate.Operators.Eq;
import parquet.filter2.predicate.Operators.FloatColumn;
Expand All @@ -22,8 +21,9 @@
import parquet.filter2.predicate.Operators.Or;
import parquet.filter2.predicate.Operators.SupportsEqNotEq;
import parquet.filter2.predicate.Operators.SupportsLtGt;
import parquet.filter2.predicate.Operators.SimpleUserDefined;
import parquet.filter2.predicate.Operators.UserDefined;
import parquet.filter2.predicate.Operators.UserDefinedByClass;
import parquet.filter2.predicate.Operators.UserDefinedByInstance;

/**
* The Filter API is expressed through these static methods.
Expand Down Expand Up @@ -148,18 +148,23 @@ public static <T extends Comparable<T>, C extends Column<T> & SupportsLtGt> GtEq

/**
* Keeps records that pass the provided {@link UserDefinedPredicate}
*
* The provided class must have a default constructor. To use an instance
* of a UserDefinedPredicate instead, see {@link #userDefined(Column, UserDefinedPredicate)} below.
*/
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T>>
UserDefined<T, U> userDefined(Column<T> column, Class<U> clazz) {
return new SimpleUserDefined<T, U>(column, clazz);
return new UserDefinedByClass<T, U>(column, clazz);
}

/**
* Similar to above but allows to pass Serializable {@link UserDefinedPredicate}
* Keeps records that pass the provided {@link UserDefinedPredicate}
*
* The provided instance of UserDefinedPredicate must be serializable.
*/
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable>
UserDefined<T, U> userDefined(Column<T> column, U udp) {
return new ConfiguredUserDefined<T, U> (column, udp);
return new UserDefinedByInstance<T, U>(column, udp);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,9 @@ public int hashCode() {

public static abstract class UserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> implements FilterPredicate, Serializable {
protected final Column<T> column;
protected String toString;
private static final String INSTANTIATION_ERROR_MESSAGE =
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";

UserDefined(Column<T> column) {
this.column = checkNotNull(column, "column");
String name = getClass().getSimpleName().toLowerCase();
this.toString = name + "(" + column.getColumnPath().toDotString() + ", UserDefined)";
}

public Column<T> getColumn() {
Expand All @@ -359,39 +354,18 @@ public Column<T> getColumn() {
public abstract U getUserDefinedPredicate();

@Override
public abstract <R> R accept(Visitor<R> visitor);

@Override
public String toString() {
return toString;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

UserDefined that = (UserDefined) o;

if (!column.equals(that.column)) return false;

return true;
}

@Override
public int hashCode() {
int result = column.hashCode();
result = result * 31 + getClass().hashCode();
return result;
public <R> R accept(Visitor<R> visitor) {
return visitor.visit(this);
}
}

public static final class SimpleUserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> extends UserDefined<T, U> {
public static final class UserDefinedByClass<T extends Comparable<T>, U extends UserDefinedPredicate<T>> extends UserDefined<T, U> {
private final Class<U> udpClass;
private final String toString;
private static final String INSTANTIATION_ERROR_MESSAGE =
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";

SimpleUserDefined(Column<T> column, Class<U> udpClass) {
UserDefinedByClass(Column<T> column, Class<U> udpClass) {
super(column);
this.udpClass = checkNotNull(udpClass, "udpClass");
String name = getClass().getSimpleName().toLowerCase();
Expand All @@ -401,14 +375,11 @@ public static final class SimpleUserDefined<T extends Comparable<T>, U extends U
getUserDefinedPredicate();
}

public Column<T> getColumn() {
return column;
}

public Class<U> getUserDefinedPredicateClass() {
return udpClass;
}

@Override
public U getUserDefinedPredicate() {
try {
return udpClass.newInstance();
Expand All @@ -419,11 +390,6 @@ public U getUserDefinedPredicate() {
}
}

@Override
public <R> R accept(Visitor<R> visitor) {
return visitor.visit(this);
}

@Override
public String toString() {
return toString;
Expand All @@ -434,7 +400,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

SimpleUserDefined that = (SimpleUserDefined) o;
UserDefinedByClass that = (UserDefinedByClass) o;

if (!column.equals(that.column)) return false;
if (!udpClass.equals(that.udpClass)) return false;
Expand All @@ -450,28 +416,21 @@ public int hashCode() {
return result;
}
}

public static final class ConfiguredUserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable > extends UserDefined<T, U> {
//private final Column<T> column;
private final U udp;

public static final class UserDefinedByInstance<T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable> extends UserDefined<T, U> {
private final String toString;
private final U udpInstance;

ConfiguredUserDefined(Column<T> column, U udp) {
//column = checkNotNull(column, "column");
UserDefinedByInstance(Column<T> column, U udpInstance) {
super(column);
this.udp = checkNotNull(udp, "udp");
this.udpInstance = checkNotNull(udpInstance, "udpInstance");
String name = getClass().getSimpleName().toLowerCase();
this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udp.getClass().getName() + ")";
this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpInstance + ")";
}

@Override
public U getUserDefinedPredicate() {
return udp;
}

@Override
public <R> R accept(Visitor<R> visitor) {
return visitor.visit(this);
return udpInstance;
}

@Override
Expand All @@ -484,18 +443,18 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

ConfiguredUserDefined that = (ConfiguredUserDefined) o;
UserDefinedByInstance that = (UserDefinedByInstance) o;

if (!column.equals(that.column)) return false;
if (!udp.equals(that.udp)) return false;
if (!udpInstance.equals(that.udpInstance)) return false;

return true;
}

@Override
public int hashCode() {
int result = column.hashCode();
result = 31 * result + udp.hashCode();
result = 31 * result + udpInstance.hashCode();
result = result * 31 + getClass().hashCode();
return result;
}
Expand Down Expand Up @@ -545,4 +504,5 @@ public int hashCode() {
return result;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import org.junit.Test;

Expand All @@ -14,10 +15,11 @@
import parquet.filter2.predicate.Operators.Eq;
import parquet.filter2.predicate.Operators.Gt;
import parquet.filter2.predicate.Operators.IntColumn;
import parquet.filter2.predicate.Operators.LongColumn;
import parquet.filter2.predicate.Operators.Not;
import parquet.filter2.predicate.Operators.Or;
import parquet.filter2.predicate.Operators.SimpleUserDefined;
import parquet.filter2.predicate.Operators.UserDefined;
import parquet.filter2.predicate.Operators.UserDefinedByClass;
import parquet.io.api.Binary;

import static org.junit.Assert.assertEquals;
Expand All @@ -28,6 +30,7 @@
import static parquet.filter2.predicate.FilterApi.eq;
import static parquet.filter2.predicate.FilterApi.gt;
import static parquet.filter2.predicate.FilterApi.intColumn;
import static parquet.filter2.predicate.FilterApi.longColumn;
import static parquet.filter2.predicate.FilterApi.not;
import static parquet.filter2.predicate.FilterApi.notEq;
import static parquet.filter2.predicate.FilterApi.or;
Expand All @@ -37,6 +40,7 @@
public class TestFilterApiMethods {

private static final IntColumn intColumn = intColumn("a.b.c");
private static final LongColumn longColumn = longColumn("a.b.l");
private static final DoubleColumn doubleColumn = doubleColumn("x.y.z");
private static final BinaryColumn binColumn = binaryColumn("a.string.column");

Expand Down Expand Up @@ -83,15 +87,15 @@ public void testUdp() {
FilterPredicate predicate = or(eq(doubleColumn, 12.0), userDefined(intColumn, DummyUdp.class));
assertTrue(predicate instanceof Or);
FilterPredicate ud = ((Or) predicate).getRight();
assertTrue(ud instanceof SimpleUserDefined);
assertEquals(DummyUdp.class, ((SimpleUserDefined) ud).getUserDefinedPredicateClass());
assertTrue(ud instanceof UserDefinedByClass);
assertEquals(DummyUdp.class, ((UserDefinedByClass) ud).getUserDefinedPredicateClass());
assertTrue(((UserDefined) ud).getUserDefinedPredicate() instanceof DummyUdp);
}

@Test
public void testSerializable() throws Exception {
public void testSerializable() throws Exception {
BinaryColumn binary = binaryColumn("foo");
FilterPredicate p = or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi")));
FilterPredicate p = and(or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))), userDefined(longColumn, new IsMultipleOf(7)));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(p);
Expand All @@ -101,4 +105,50 @@ public void testSerializable() throws Exception {
FilterPredicate read = (FilterPredicate) is.readObject();
assertEquals(p, read);
}

/**
 * Test predicate that keeps {@code Long} values evenly divisible by a fixed divisor.
 *
 * It is {@link Serializable} so that an instance can be handed to the
 * instance-based {@code userDefined(column, udp)} factory and round-tripped through
 * Java serialization; {@code equals}/{@code hashCode} are defined on the divisor so
 * that a deserialized copy compares equal to the original.
 */
public static class IsMultipleOf extends UserDefinedPredicate<Long> implements Serializable {

  // Immutable after construction: equals() and hashCode() are derived from it.
  private final long of;

  /**
   * @param of the divisor values are tested against; must not be zero
   * @throws IllegalArgumentException if {@code of} is zero, since every later
   *         call to {@link #keep(Long)} would fail with a division by zero
   */
  public IsMultipleOf(long of) {
    if (of == 0) {
      throw new IllegalArgumentException("of must not be zero");
    }
    this.of = of;
  }

  /**
   * @param value the column value, possibly {@code null}
   * @return {@code true} if {@code value} is non-null and a multiple of the divisor
   */
  @Override
  public boolean keep(Long value) {
    if (value == null) {
      return false;
    }
    return value % of == 0;
  }

  /** Always returns {@code false}: this predicate never drops anything based on statistics. */
  @Override
  public boolean canDrop(Statistics<Long> statistics) {
    return false;
  }

  /** Always returns {@code false}: the inverted predicate never drops anything based on statistics either. */
  @Override
  public boolean inverseCanDrop(Statistics<Long> statistics) {
    return false;
  }

  /** Two instances are equal when they are of the same class and share the same divisor. */
  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;

    IsMultipleOf that = (IsMultipleOf) o;
    return this.of == that.of;
  }

  /** Hash of the divisor, consistent with {@link #equals(Object)}. */
  @Override
  public int hashCode() {
    // Long.valueOf avoids the deprecated Long(long) constructor and yields the same hash value.
    return Long.valueOf(of).hashCode();
  }

  @Override
  public String toString() {
    return "IsMultipleOf(" + of + ")";
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public boolean inverseCanDrop(Statistics<Binary> statistics) {

public static class SetInFilter extends UserDefinedPredicate<Long> implements Serializable {

HashSet<Long> hSet;
private HashSet<Long> hSet;

public SetInFilter(HashSet<Long> phSet) {
hSet = phSet;
Expand Down Expand Up @@ -211,23 +211,22 @@ public boolean keep(User u) {
}

@Test
public void testIdIn() throws Exception {
public void testUserDefinedByInstance() throws Exception {
LongColumn name = longColumn("id");

HashSet<Long> h = new HashSet<Long>() {{
add(20L); add(27L); add(28L);
}};
final HashSet<Long> h = new HashSet<Long>();
h.add(20L);
h.add(27L);
h.add(28L);

FilterPredicate pred = userDefined(name, new SetInFilter(h));

List<Group> found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(pred));

assertFilter(found, new UserFilter() {
@Override
public boolean keep(User u) {
Set<Long> h = new HashSet<Long>() {{
add(20L); add(27L); add(28L);
}};
return h.contains(u.getId());
return u != null && h.contains(u.getId());
}
});
}
Expand Down
3 changes: 3 additions & 0 deletions parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package parquet.filter2.dsl

import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong}
import java.io.Serializable

import parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators, UserDefinedPredicate}
import parquet.io.api.Binary
Expand Down Expand Up @@ -30,6 +31,8 @@ object Dsl {
val javaColumn: C

/**
 * Builds a user-defined filter on this column from a predicate class by
 * delegating to FilterApi.userDefined(javaColumn, clazz).
 *
 * @param clazz the UserDefinedPredicate class to filter with
 */
def filterBy[U <: UserDefinedPredicate[T]](clazz: Class[U]) = FilterApi.userDefined(javaColumn, clazz)

/**
 * Builds a user-defined filter on this column from a predicate instance by
 * delegating to FilterApi.userDefined(javaColumn, udp).
 *
 * @param udp the predicate instance; its type must also be java.io.Serializable
 */
def filterBy[U <: UserDefinedPredicate[T] with Serializable](udp: U) = FilterApi.userDefined(javaColumn, udp)

// this is not supported because it allows for easy mistakes. For example:
// val pred = IntColumn("foo") == "hello"
Expand Down
Loading

0 comments on commit 7bfa5ad

Please sign in to comment.