diff --git a/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java b/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java index 22e99702cc..4a4ad0b775 100644 --- a/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java +++ b/parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java @@ -18,6 +18,8 @@ */ package parquet.filter2.predicate; +import java.io.Serializable; + import parquet.common.schema.ColumnPath; import parquet.filter2.predicate.Operators.And; import parquet.filter2.predicate.Operators.BinaryColumn; @@ -38,6 +40,8 @@ import parquet.filter2.predicate.Operators.SupportsEqNotEq; import parquet.filter2.predicate.Operators.SupportsLtGt; import parquet.filter2.predicate.Operators.UserDefined; +import parquet.filter2.predicate.Operators.UserDefinedByClass; +import parquet.filter2.predicate.Operators.UserDefinedByInstance; /** * The Filter API is expressed through these static methods. @@ -162,10 +166,23 @@ public static , C extends Column & SupportsLtGt> GtEq /** * Keeps records that pass the provided {@link UserDefinedPredicate} + * + * The provided class must have a default constructor. To use an instance + * of a UserDefinedPredicate instead, see {@link #userDefined(column, udp)} below. */ public static , U extends UserDefinedPredicate> UserDefined userDefined(Column column, Class clazz) { - return new UserDefined(column, clazz); + return new UserDefinedByClass(column, clazz); + } + + /** + * Keeps records that pass the provided {@link UserDefinedPredicate} + * + * The provided instance of UserDefinedPredicate must be serializable. + */ + public static , U extends UserDefinedPredicate & Serializable> + UserDefined userDefined(Column column, U udp) { + return new UserDefinedByInstance(column, udp); } /** diff --git a/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java b/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java index 61898a24b4..80c5a831da 100644 --- a/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java +++ b/parquet-column/src/main/java/parquet/filter2/predicate/Operators.java @@ -358,15 +358,33 @@ public int hashCode() { } } - public static final class UserDefined, U extends UserDefinedPredicate> implements FilterPredicate, Serializable { - private final Column column; + public static abstract class UserDefined, U extends UserDefinedPredicate> implements FilterPredicate, Serializable { + protected final Column column; + + UserDefined(Column column) { + this.column = checkNotNull(column, "column"); + } + + public Column getColumn() { + return column; + } + + public abstract U getUserDefinedPredicate(); + + @Override + public R accept(Visitor visitor) { + return visitor.visit(this); + } + } + + public static final class UserDefinedByClass, U extends UserDefinedPredicate> extends UserDefined { private final Class udpClass; private final String toString; private static final String INSTANTIATION_ERROR_MESSAGE = "Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor."; - UserDefined(Column column, Class udpClass) { - this.column = checkNotNull(column, "column"); + UserDefinedByClass(Column column, Class udpClass) { + super(column); this.udpClass = checkNotNull(udpClass, "udpClass"); String name = getClass().getSimpleName().toLowerCase(); this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpClass.getName() + ")"; @@ -375,14 +393,11 @@ public static final class UserDefined, U extends UserDef getUserDefinedPredicate(); } - public Column getColumn() { - return column; - } - public Class getUserDefinedPredicateClass() { return udpClass; } + @Override public U getUserDefinedPredicate() { try { return udpClass.newInstance(); @@ -394,8 +409,46 @@ public U getUserDefinedPredicate() { } @Override - public R accept(Visitor visitor) { - return visitor.visit(this); + public String toString() { + return toString; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + UserDefinedByClass that = (UserDefinedByClass) o; + + if (!column.equals(that.column)) return false; + if (!udpClass.equals(that.udpClass)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = column.hashCode(); + result = 31 * result + udpClass.hashCode(); + result = result * 31 + getClass().hashCode(); + return result; + } + } + + public static final class UserDefinedByInstance, U extends UserDefinedPredicate & Serializable> extends UserDefined { + private final String toString; + private final U udpInstance; + + UserDefinedByInstance(Column column, U udpInstance) { + super(column); + this.udpInstance = checkNotNull(udpInstance, "udpInstance"); + String name = getClass().getSimpleName().toLowerCase(); + this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpInstance + ")"; + } + + @Override + public U getUserDefinedPredicate() { + return udpInstance; } @Override @@ -408,10 +461,10 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - UserDefined that = (UserDefined) o; + UserDefinedByInstance that = (UserDefinedByInstance) o; if (!column.equals(that.column)) return false; - if (!udpClass.equals(that.udpClass)) return false; + if (!udpInstance.equals(that.udpInstance)) return false; return true; } @@ -419,7 +472,7 @@ public boolean equals(Object o) { @Override public int hashCode() { int result = column.hashCode(); - result = 31 * result + udpClass.hashCode(); + result = 31 * result + udpInstance.hashCode(); result = result * 31 + getClass().hashCode(); return result; } diff --git a/parquet-column/src/main/java/parquet/filter2/predicate/UserDefinedPredicate.java b/parquet-column/src/main/java/parquet/filter2/predicate/UserDefinedPredicate.java index 4025450e05..e03c9459fa 100644 --- a/parquet-column/src/main/java/parquet/filter2/predicate/UserDefinedPredicate.java +++ b/parquet-column/src/main/java/parquet/filter2/predicate/UserDefinedPredicate.java @@ -105,4 +105,4 @@ public UserDefinedPredicate() { } * } */ public abstract boolean inverseCanDrop(Statistics statistics); -} \ No newline at end of file +} diff --git a/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java b/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java index a92d480fa1..849d946581 100644 --- a/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java +++ b/parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java @@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.Serializable; import org.junit.Test; @@ -32,9 +33,11 @@ import parquet.filter2.predicate.Operators.Eq; import parquet.filter2.predicate.Operators.Gt; import parquet.filter2.predicate.Operators.IntColumn; +import parquet.filter2.predicate.Operators.LongColumn; import parquet.filter2.predicate.Operators.Not; import parquet.filter2.predicate.Operators.Or; import parquet.filter2.predicate.Operators.UserDefined; +import parquet.filter2.predicate.Operators.UserDefinedByClass; import parquet.io.api.Binary; import static org.junit.Assert.assertEquals; @@ -45,6 +48,7 @@ import static parquet.filter2.predicate.FilterApi.eq; import static parquet.filter2.predicate.FilterApi.gt; import static parquet.filter2.predicate.FilterApi.intColumn; +import static parquet.filter2.predicate.FilterApi.longColumn; import static parquet.filter2.predicate.FilterApi.not; import static parquet.filter2.predicate.FilterApi.notEq; import static parquet.filter2.predicate.FilterApi.or; @@ -54,6 +58,7 @@ public class TestFilterApiMethods { private static final IntColumn intColumn = intColumn("a.b.c"); + private static final LongColumn longColumn = longColumn("a.b.l"); private static final DoubleColumn doubleColumn = doubleColumn("x.y.z"); private static final BinaryColumn binColumn = binaryColumn("a.string.column"); @@ -100,15 +105,15 @@ public void testUdp() { FilterPredicate predicate = or(eq(doubleColumn, 12.0), userDefined(intColumn, DummyUdp.class)); assertTrue(predicate instanceof Or); FilterPredicate ud = ((Or) predicate).getRight(); - assertTrue(ud instanceof UserDefined); - assertEquals(DummyUdp.class, ((UserDefined) ud).getUserDefinedPredicateClass()); + assertTrue(ud instanceof UserDefinedByClass); + assertEquals(DummyUdp.class, ((UserDefinedByClass) ud).getUserDefinedPredicateClass()); assertTrue(((UserDefined) ud).getUserDefinedPredicate() instanceof DummyUdp); } @Test - public void testSerializable() throws Exception { + public void testSerializable() throws Exception { BinaryColumn binary = binaryColumn("foo"); - FilterPredicate p = or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))); + FilterPredicate p = and(or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))), userDefined(longColumn, new IsMultipleOf(7))); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(p); @@ -118,4 +123,50 @@ public void testSerializable() throws Exception { FilterPredicate read = (FilterPredicate) is.readObject(); assertEquals(p, read); } + + public static class IsMultipleOf extends UserDefinedPredicate implements Serializable { + + private long of; + + public IsMultipleOf(long of) { + this.of = of; + } + + @Override + public boolean keep(Long value) { + if (value == null) { + return false; + } + return value % of == 0; + } + + @Override + public boolean canDrop(Statistics statistics) { + return false; + } + + @Override + public boolean inverseCanDrop(Statistics statistics) { + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + IsMultipleOf that = (IsMultipleOf) o; + return this.of == that.of; + } + + @Override + public int hashCode() { + return new Long(of).hashCode(); + } + + @Override + public String toString() { + return "IsMultipleOf(" + of + ")"; + } + } } diff --git a/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java b/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java index bec9f0be2c..c112bd9178 100644 --- a/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java +++ b/parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java @@ -20,10 +20,13 @@ import java.io.File; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; +import java.util.Set; +import java.util.HashSet; import org.junit.BeforeClass; import org.junit.Test; @@ -33,6 +36,7 @@ import parquet.filter2.predicate.FilterPredicate; import parquet.filter2.predicate.Operators.BinaryColumn; import parquet.filter2.predicate.Operators.DoubleColumn; +import parquet.filter2.predicate.Operators.LongColumn; import parquet.filter2.predicate.Statistics; import parquet.filter2.predicate.UserDefinedPredicate; import parquet.filter2.recordlevel.PhoneBookWriter.Location; @@ -44,6 +48,7 @@ import static parquet.filter2.predicate.FilterApi.and; import static parquet.filter2.predicate.FilterApi.binaryColumn; import static parquet.filter2.predicate.FilterApi.doubleColumn; +import static parquet.filter2.predicate.FilterApi.longColumn; import static parquet.filter2.predicate.FilterApi.eq; import static parquet.filter2.predicate.FilterApi.gt; import static parquet.filter2.predicate.FilterApi.not; @@ -178,6 +183,34 @@ public boolean inverseCanDrop(Statistics statistics) { return false; } } + + public static class SetInFilter extends UserDefinedPredicate implements Serializable { + + private HashSet hSet; + + public SetInFilter(HashSet phSet) { + hSet = phSet; + } + + @Override + public boolean keep(Long value) { + if (value == null) { + return false; + } + + return hSet.contains(value); + } + + @Override + public boolean canDrop(Statistics statistics) { + return false; + } + + @Override + public boolean inverseCanDrop(Statistics statistics) { + return false; + } + } @Test public void testNameNotStartWithP() throws Exception { @@ -194,6 +227,27 @@ public boolean keep(User u) { } }); } + + @Test + public void testUserDefinedByInstance() throws Exception { + LongColumn name = longColumn("id"); + + final HashSet h = new HashSet(); + h.add(20L); + h.add(27L); + h.add(28L); + + FilterPredicate pred = userDefined(name, new SetInFilter(h)); + + List found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(pred)); + + assertFilter(found, new UserFilter() { + @Override + public boolean keep(User u) { + return u != null && h.contains(u.getId()); + } + }); + } @Test public void testComplex() throws Exception { diff --git a/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala b/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala index 38e205ebde..8711300f99 100644 --- a/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala +++ b/parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala @@ -19,6 +19,7 @@ package parquet.filter2.dsl import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong} +import java.io.Serializable import parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators, UserDefinedPredicate} import parquet.io.api.Binary @@ -48,6 +49,8 @@ object Dsl { val javaColumn: C def filterBy[U <: UserDefinedPredicate[T]](clazz: Class[U]) = FilterApi.userDefined(javaColumn, clazz) + + def filterBy[U <: UserDefinedPredicate[T] with Serializable](udp: U) = FilterApi.userDefined(javaColumn, udp) // this is not supported because it allows for easy mistakes. For example: // val pred = IntColumn("foo") == "hello" diff --git a/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala b/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala index d40367a9f2..eed9b522cd 100644 --- a/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala +++ b/parquet-scala/src/test/scala/parquet/filter2/dsl/DslTest.scala @@ -19,14 +19,15 @@ package parquet.filter2.dsl import java.lang.{Double => JDouble, Integer => JInt} +import java.io.Serializable import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner -import parquet.filter2.predicate.Operators.{Or, UserDefined, DoubleColumn => JDoubleColumn, IntColumn => JIntColumn} +import parquet.filter2.predicate.Operators.{Or, UserDefined, UserDefinedByClass, DoubleColumn => JDoubleColumn, IntColumn => JIntColumn} import parquet.filter2.predicate.{FilterApi, Statistics, UserDefinedPredicate} -class DummyFilter extends UserDefinedPredicate[JInt] { +class DummyFilter extends UserDefinedPredicate[JInt] with Serializable { override def keep(value: JInt): Boolean = false override def canDrop(statistics: Statistics[JInt]): Boolean = false @@ -55,14 +56,21 @@ class DslTest extends FlatSpec{ "user defined predicates" should "be correctly constructed" in { val abc = IntColumn("a.b.c") - val pred = (abc > 10) || abc.filterBy(classOf[DummyFilter]) + val predByClass = (abc > 10) || abc.filterBy(classOf[DummyFilter]) + val instance = new DummyFilter + val predByInstance = (abc > 10) || abc.filterBy(instance) - val expected = FilterApi.or(FilterApi.gt[JInt, JIntColumn](abc.javaColumn, 10), FilterApi.userDefined(abc.javaColumn, classOf[DummyFilter])) - assert(pred === expected) - val intUserDefined = pred.asInstanceOf[Or].getRight.asInstanceOf[UserDefined[JInt, DummyFilter]] - - assert(intUserDefined.getUserDefinedPredicateClass === classOf[DummyFilter]) - assert(intUserDefined.getUserDefinedPredicate.isInstanceOf[DummyFilter]) + val expectedByClass = FilterApi.or(FilterApi.gt[JInt, JIntColumn](abc.javaColumn, 10), FilterApi.userDefined(abc.javaColumn, classOf[DummyFilter])) + val expectedByInstance = FilterApi.or(FilterApi.gt[JInt, JIntColumn](abc.javaColumn, 10), FilterApi.userDefined(abc.javaColumn, instance)) + assert(predByClass === expectedByClass) + assert(predByInstance === expectedByInstance) + + val intUserDefinedByClass = predByClass.asInstanceOf[Or].getRight.asInstanceOf[UserDefinedByClass[JInt, DummyFilter]] + assert(intUserDefinedByClass.getUserDefinedPredicateClass === classOf[DummyFilter]) + assert(intUserDefinedByClass.getUserDefinedPredicate.isInstanceOf[DummyFilter]) + + val intUserDefinedByInstance = predByInstance.asInstanceOf[Or].getRight.asInstanceOf[UserDefined[JInt, DummyFilter]] + assert(intUserDefinedByInstance.getUserDefinedPredicate === instance) } "Column == and != " should "throw a helpful warning" in { diff --git a/pom.xml b/pom.xml index 18f4fc89b1..a753b3519c 100644 --- a/pom.xml +++ b/pom.xml @@ -274,6 +274,7 @@ true ${previous.version} + parquet/filter2/** parquet/org/** parquet/io/api/Binary