diff --git a/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java b/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java index a3c5d08863..cb82798903 100644 --- a/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java +++ b/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java @@ -7,7 +7,6 @@ import parquet.filter2.predicate.Operators.BinaryColumn; import parquet.filter2.predicate.Operators.BooleanColumn; import parquet.filter2.predicate.Operators.Column; -import parquet.filter2.predicate.Operators.ConfiguredUserDefined; import parquet.filter2.predicate.Operators.DoubleColumn; import parquet.filter2.predicate.Operators.Eq; import parquet.filter2.predicate.Operators.FloatColumn; @@ -22,8 +21,9 @@ import parquet.filter2.predicate.Operators.Or; import parquet.filter2.predicate.Operators.SupportsEqNotEq; import parquet.filter2.predicate.Operators.SupportsLtGt; -import parquet.filter2.predicate.Operators.SimpleUserDefined; import parquet.filter2.predicate.Operators.UserDefined; +import parquet.filter2.predicate.Operators.UserDefinedByClass; +import parquet.filter2.predicate.Operators.UserDefinedByInstance; /** * The Filter API is expressed through these static methods. @@ -148,18 +148,23 @@ public static , C extends Column & SupportsLtGt> GtEq /** * Keeps records that pass the provided {@link UserDefinedPredicate} + * + * The provided class must have a default constructor. To use an instance + * of a UserDefinedPredicate instead, see {@link #userDefined(column, udp)} below. */ public static , U extends UserDefinedPredicate> UserDefined userDefined(Column column, Class clazz) { - return new SimpleUserDefined(column, clazz); + return new UserDefinedByClass(column, clazz); } /** - * Similar to above but allows to pass Serializable {@link UserDefinedPredicate} + * Keeps records that pass the provided {@link UserDefinedPredicate} + * + * The provided instance of UserDefinedPredicate must be serializable. */ public static , U extends UserDefinedPredicate & Serializable> UserDefined userDefined(Column column, U udp) { - return new ConfiguredUserDefined (column, udp); + return new UserDefinedByInstance(column, udp); } /** diff --git a/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java b/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java index 29e1c87275..d02e124364 100644 --- a/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java +++ b/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java @@ -342,14 +342,9 @@ public int hashCode() { public static abstract class UserDefined, U extends UserDefinedPredicate> implements FilterPredicate, Serializable { protected final Column column; - protected String toString; - private static final String INSTANTIATION_ERROR_MESSAGE = - "Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor."; UserDefined(Column column) { this.column = checkNotNull(column, "column"); - String name = getClass().getSimpleName().toLowerCase(); - this.toString = name + "(" + column.getColumnPath().toDotString() + ", UserDefined)"; } public Column getColumn() { @@ -359,39 +354,18 @@ public Column getColumn() { public abstract U getUserDefinedPredicate(); @Override - public abstract R accept(Visitor visitor); - - @Override - public String toString() { - return toString; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - UserDefined that = (UserDefined) o; - - if (!column.equals(that.column)) return false; - - return true; - } - - @Override - public int hashCode() { - int result = column.hashCode(); - result = result * 31 + getClass().hashCode(); - return result; + public R accept(Visitor visitor) { + return visitor.visit(this); } } - - public static final class SimpleUserDefined, U extends UserDefinedPredicate> extends UserDefined { + + public static final class UserDefinedByClass, U extends UserDefinedPredicate> extends UserDefined { private final Class udpClass; + private final String toString; private static final String INSTANTIATION_ERROR_MESSAGE = - "Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor."; + "Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor."; - SimpleUserDefined(Column column, Class udpClass) { + UserDefinedByClass(Column column, Class udpClass) { super(column); this.udpClass = checkNotNull(udpClass, "udpClass"); String name = getClass().getSimpleName().toLowerCase(); @@ -401,14 +375,11 @@ public static final class SimpleUserDefined, U extends U getUserDefinedPredicate(); } - public Column getColumn() { - return column; - } - public Class getUserDefinedPredicateClass() { return udpClass; } + @Override public U getUserDefinedPredicate() { try { return udpClass.newInstance(); @@ -419,11 +390,6 @@ public U getUserDefinedPredicate() { } } - @Override - public R accept(Visitor visitor) { - return visitor.visit(this); - } - @Override public String toString() { return toString; @@ -434,7 +400,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - SimpleUserDefined that = (SimpleUserDefined) o; + UserDefinedByClass that = (UserDefinedByClass) o; if (!column.equals(that.column)) return false; if (!udpClass.equals(that.udpClass)) return false; @@ -450,28 +416,21 @@ public int hashCode() { return result; } } - - public static final class ConfiguredUserDefined, U extends UserDefinedPredicate & Serializable > extends UserDefined { - //private final Column column; - private final U udp; + + public static final class UserDefinedByInstance, U extends UserDefinedPredicate & Serializable> extends UserDefined { private final String toString; + private final U udpInstance; - ConfiguredUserDefined(Column column, U udp) { - //column = checkNotNull(column, "column"); + UserDefinedByInstance(Column column, U udpInstance) { super(column); - this.udp = checkNotNull(udp, "udp"); + this.udpInstance = checkNotNull(udpInstance, "udpInstance"); String name = getClass().getSimpleName().toLowerCase(); - this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udp.getClass().getName() + ")"; + this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpInstance + ")"; } @Override public U getUserDefinedPredicate() { - return udp; - } - - @Override - public R accept(Visitor visitor) { - return visitor.visit(this); + return udpInstance; } @Override @@ -484,10 +443,10 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ConfiguredUserDefined that = (ConfiguredUserDefined) o; + UserDefinedByInstance that = (UserDefinedByInstance) o; if (!column.equals(that.column)) return false; - if (!udp.equals(that.udp)) return false; + if (!udpInstance.equals(that.udpInstance)) return false; return true; } @@ -495,7 +454,7 @@ public boolean equals(Object o) { @Override public int hashCode() { int result = column.hashCode(); - result = 31 * result + udp.hashCode(); + result = 31 * result + udpInstance.hashCode(); result = result * 31 + getClass().hashCode(); return result; } @@ -545,4 +504,5 @@ public int hashCode() { return result; } } + } diff --git a/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java b/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java index 63a61c40ee..665db1e79f 100644 --- a/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java +++ b/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java @@ -4,6 +4,7 @@ import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.Serializable; import org.junit.Test; @@ -14,10 +15,11 @@ import parquet.filter2.predicate.Operators.Eq; import parquet.filter2.predicate.Operators.Gt; import parquet.filter2.predicate.Operators.IntColumn; +import parquet.filter2.predicate.Operators.LongColumn; import parquet.filter2.predicate.Operators.Not; import parquet.filter2.predicate.Operators.Or; -import parquet.filter2.predicate.Operators.SimpleUserDefined; import parquet.filter2.predicate.Operators.UserDefined; +import parquet.filter2.predicate.Operators.UserDefinedByClass; import parquet.io.api.Binary; import static org.junit.Assert.assertEquals; @@ -28,6 +30,7 @@ import static parquet.filter2.predicate.FilterApi.eq; import static parquet.filter2.predicate.FilterApi.gt; import static parquet.filter2.predicate.FilterApi.intColumn; +import static parquet.filter2.predicate.FilterApi.longColumn; import static parquet.filter2.predicate.FilterApi.not; import static parquet.filter2.predicate.FilterApi.notEq; import static parquet.filter2.predicate.FilterApi.or; @@ -37,6 +40,7 @@ public class TestFilterApiMethods { private static final IntColumn intColumn = intColumn("a.b.c"); + private static final LongColumn longColumn = longColumn("a.b.l"); private static final DoubleColumn doubleColumn = doubleColumn("x.y.z"); private static final BinaryColumn binColumn = binaryColumn("a.string.column"); @@ -83,15 +87,15 @@ public void testUdp() { FilterPredicate predicate = or(eq(doubleColumn, 12.0), userDefined(intColumn, DummyUdp.class)); assertTrue(predicate instanceof Or); FilterPredicate ud = ((Or) predicate).getRight(); - assertTrue(ud instanceof SimpleUserDefined); - assertEquals(DummyUdp.class, ((SimpleUserDefined) ud).getUserDefinedPredicateClass()); + assertTrue(ud instanceof UserDefinedByClass); + assertEquals(DummyUdp.class, ((UserDefinedByClass) ud).getUserDefinedPredicateClass()); assertTrue(((UserDefined) ud).getUserDefinedPredicate() instanceof DummyUdp); } @Test - public void testSerializable() throws Exception { + public void testSerializable() throws Exception { BinaryColumn binary = binaryColumn("foo"); - FilterPredicate p = or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))); + FilterPredicate p = and(or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))), userDefined(longColumn, new IsMultipleOf(7))); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(p); @@ -101,4 +105,50 @@ public void testSerializable() throws Exception { FilterPredicate read = (FilterPredicate) is.readObject(); assertEquals(p, read); } + + public static class IsMultipleOf extends UserDefinedPredicate implements Serializable { + + private long of; + + public IsMultipleOf(long of) { + this.of = of; + } + + @Override + public boolean keep(Long value) { + if (value == null) { + return false; + } + return value % of == 0; + } + + @Override + public boolean canDrop(Statistics statistics) { + return false; + } + + @Override + public boolean inverseCanDrop(Statistics statistics) { + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + IsMultipleOf that = (IsMultipleOf) o; + return this.of == that.of; + } + + @Override + public int hashCode() { + return new Long(of).hashCode(); + } + + @Override + public String toString() { + return "IsMultipleOf(" + of + ")"; + } + } } diff --git a/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java b/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java index 9e969f3028..9e488f3859 100644 --- a/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java +++ b/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java @@ -168,7 +168,7 @@ public boolean inverseCanDrop(Statistics statistics) { public static class SetInFilter extends UserDefinedPredicate implements Serializable { - HashSet hSet; + private HashSet hSet; public SetInFilter(HashSet phSet) { hSet = phSet; @@ -211,12 +211,14 @@ public boolean keep(User u) { } @Test - public void testIdIn() throws Exception { + public void testUserDefinedByInstance() throws Exception { LongColumn name = longColumn("id"); - HashSet h = new HashSet() {{ - add(20L); add(27L); add(28L); - }}; + final HashSet h = new HashSet(); + h.add(20L); + h.add(27L); + h.add(28L); + FilterPredicate pred = userDefined(name, new SetInFilter(h)); List found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(pred)); @@ -224,10 +226,7 @@ public void testIdIn() throws Exception { assertFilter(found, new UserFilter() { @Override public boolean keep(User u) { - Set h = new HashSet() {{ - add(20L); add(27L); add(28L); - }}; - return h.contains(u.getId()); + return u != null && h.contains(u.getId()); } }); } diff --git a/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala b/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala index 7e3997758a..7a07135fb0 100644 --- a/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala +++ b/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala @@ -1,6 +1,7 @@ package parquet.filter2.dsl import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong} +import java.io.Serializable import parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators, UserDefinedPredicate} import parquet.io.api.Binary @@ -30,6 +31,8 @@ object Dsl { val javaColumn: C def filterBy[U <: UserDefinedPredicate[T]](clazz: Class[U]) = FilterApi.userDefined(javaColumn, clazz) + + def filterBy[U <: UserDefinedPredicate[T] with Serializable](udp: U) = FilterApi.userDefined(javaColumn, udp) // this is not supported because it allows for easy mistakes. For example: // val pred = IntColumn("foo") == "hello" diff --git a/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala b/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala index f34fbd9bdf..e97726c39d 100644 --- a/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala +++ b/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala @@ -1,14 +1,15 @@ package parquet.filter2.dsl import java.lang.{Double => JDouble, Integer => JInt} +import java.io.Serializable import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner -import parquet.filter2.predicate.Operators.{Or, SimpleUserDefined, DoubleColumn => JDoubleColumn, IntColumn => JIntColumn} +import parquet.filter2.predicate.Operators.{Or, UserDefined, UserDefinedByClass, DoubleColumn => JDoubleColumn, IntColumn => JIntColumn} import parquet.filter2.predicate.{FilterApi, Statistics, UserDefinedPredicate} -class DummyFilter extends UserDefinedPredicate[JInt] { +class DummyFilter extends UserDefinedPredicate[JInt] with Serializable { override def keep(value: JInt): Boolean = false override def canDrop(statistics: Statistics[JInt]): Boolean = false @@ -37,14 +38,21 @@ class DslTest extends FlatSpec{ "user defined predicates" should "be correctly constructed" in { val abc = IntColumn("a.b.c") - val pred = (abc > 10) || abc.filterBy(classOf[DummyFilter]) + val predByClass = (abc > 10) || abc.filterBy(classOf[DummyFilter]) + val instance = new DummyFilter + val predByInstance = (abc > 10) || abc.filterBy(instance) - val expected = FilterApi.or(FilterApi.gt[JInt, JIntColumn](abc.javaColumn, 10), FilterApi.userDefined(abc.javaColumn, classOf[DummyFilter])) - assert(pred === expected) - val intUserDefined = pred.asInstanceOf[Or].getRight.asInstanceOf[SimpleUserDefined[JInt, DummyFilter]] - - assert(intUserDefined.getUserDefinedPredicateClass === classOf[DummyFilter]) - assert(intUserDefined.getUserDefinedPredicate.isInstanceOf[DummyFilter]) + val expectedByClass = FilterApi.or(FilterApi.gt[JInt, JIntColumn](abc.javaColumn, 10), FilterApi.userDefined(abc.javaColumn, classOf[DummyFilter])) + val expectedByInstance = FilterApi.or(FilterApi.gt[JInt, JIntColumn](abc.javaColumn, 10), FilterApi.userDefined(abc.javaColumn, instance)) + assert(predByClass === expectedByClass) + assert(predByInstance === expectedByInstance) + + val intUserDefinedByClass = predByClass.asInstanceOf[Or].getRight.asInstanceOf[UserDefinedByClass[JInt, DummyFilter]] + assert(intUserDefinedByClass.getUserDefinedPredicateClass === classOf[DummyFilter]) + assert(intUserDefinedByClass.getUserDefinedPredicate.isInstanceOf[DummyFilter]) + + val intUserDefinedByInstance = predByInstance.asInstanceOf[Or].getRight.asInstanceOf[UserDefined[JInt, DummyFilter]] + assert(intUserDefinedByInstance.getUserDefinedPredicate === instance) } "Column == and != " should "throw a helpful warning" in {