diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatTransform.java b/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatTransform.java
index 4c99df2dfc0d..a220beca9f4c 100644
--- a/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatTransform.java
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatTransform.java
@@ -35,4 +35,9 @@ public ConcatTransform(List inputs) {
     public BinaryString transform(List inputs) {
         return BinaryString.concat(inputs);
     }
+
+    @Override
+    public Transform withNewInputs(List inputs) {
+        return new ConcatTransform(inputs);
+    }
 }
diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatWsTransform.java b/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatWsTransform.java
index b121799cd8eb..5c7df4f47d22 100644
--- a/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatWsTransform.java
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/ConcatWsTransform.java
@@ -39,4 +39,9 @@ public BinaryString transform(List inputs) {
         BinaryString separator = inputs.get(0);
         return BinaryString.concatWs(separator, inputs.subList(1, inputs.size()));
     }
+
+    @Override
+    public Transform withNewInputs(List inputs) {
+        return new ConcatWsTransform(inputs);
+    }
 }
diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java b/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java
new file mode 100644
index 000000000000..4b842c39baa6
--- /dev/null
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.predicate;
+
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.types.DataType;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.apache.paimon.utils.InternalRowUtils.get;
+
+/** Transform that extracts a field from a row. */
+public class FieldTransform implements Transform {
+
+    private final FieldRef fieldRef;
+
+    public FieldTransform(FieldRef fieldRef) {
+        this.fieldRef = fieldRef;
+    }
+
+    public FieldRef fieldRef() {
+        return fieldRef;
+    }
+
+    @Override
+    public List inputs() {
+        return Collections.singletonList(fieldRef);
+    }
+
+    @Override
+    public DataType outputType() {
+        return fieldRef.type();
+    }
+
+    @Override
+    public Object transform(InternalRow row) {
+        return get(row, fieldRef.index(), fieldRef.type());
+    }
+
+    @Override
+    public Transform withNewInputs(List inputs) {
+        assert inputs.size() == 1;
+        return new FieldTransform((FieldRef) inputs.get(0));
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        FieldTransform that = (FieldTransform) o;
+        return Objects.equals(fieldRef, that.fieldRef);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(fieldRef);
+    }
+}
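To make the relationship between the new pieces concrete, here is a small sketch (not part of the patch) of how a field-level predicate is built through the transform API: a `FieldRef` wrapped in a `FieldTransform` and passed to `TransformPredicate.of` collapses back to a plain `LeafPredicate`. The class and method names come from the files above; the schema and values are invented for illustration, so treat this as a sketch rather than library documentation.

```java
import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.GenericRow;
import org.apache.paimon.predicate.Equal;
import org.apache.paimon.predicate.FieldRef;
import org.apache.paimon.predicate.FieldTransform;
import org.apache.paimon.predicate.LeafPredicate;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.TransformPredicate;
import org.apache.paimon.types.DataTypes;

import java.util.Collections;

public class FieldTransformSketch {
    public static void main(String[] args) {
        // Hypothetical schema: field 0 is a STRING column named "name".
        FieldRef nameRef = new FieldRef(0, "name", DataTypes.STRING());
        FieldTransform transform = new FieldTransform(nameRef);

        // Because the transform is a plain field access, the factory hands back a LeafPredicate,
        // so existing code that special-cases LeafPredicate keeps working for simple field filters.
        Predicate p =
                TransformPredicate.of(
                        transform,
                        Equal.INSTANCE,
                        Collections.singletonList(BinaryString.fromString("paimon")));
        System.out.println(p instanceof LeafPredicate); // true

        // Row-level evaluation goes through Transform#transform(InternalRow).
        System.out.println(p.test(GenericRow.of(BinaryString.fromString("paimon")))); // true
    }
}
```

The `of` factory is what keeps `LeafPredicate`-specific visitors and the stats-based pruning path unchanged for plain field references, which is the point of the `LeafPredicate` refactor in the next file.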
diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java b/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java
index 5267f7069f90..0ec9c03a3f67 100644
--- a/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java
@@ -30,6 +30,7 @@ import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -37,28 +38,22 @@ import static org.apache.paimon.utils.InternalRowUtils.get;
 
 /** Leaf node of a {@link Predicate} tree. Compares a field in the row with literals. */
-public class LeafPredicate implements Predicate {
+public class LeafPredicate extends TransformPredicate {
 
     private static final long serialVersionUID = 1L;
 
-    private final LeafFunction function;
-    private final DataType type;
-    private final int fieldIndex;
-    private final String fieldName;
-
-    private transient List literals;
-
     public LeafPredicate(
             LeafFunction function,
             DataType type,
             int fieldIndex,
             String fieldName,
             List literals) {
-        this.function = function;
-        this.type = type;
-        this.fieldIndex = fieldIndex;
-        this.fieldName = fieldName;
-        this.literals = literals;
+        this(new FieldTransform(new FieldRef(fieldIndex, fieldName, type)), function, literals);
+    }
+
+    public LeafPredicate(
+            FieldTransform fieldTransform, LeafFunction function, List literals) {
+        super(fieldTransform, function, literals);
     }
 
     public LeafFunction function() {
@@ -66,19 +61,23 @@ public LeafFunction function() {
     }
 
     public DataType type() {
-        return type;
+        return fieldRef().type();
     }
 
     public int index() {
-        return fieldIndex;
+        return fieldRef().index();
     }
 
     public String fieldName() {
-        return fieldName;
+        return fieldRef().name();
+    }
+
+    public List fieldNames() {
+        return Collections.singletonList(fieldRef().name());
     }
 
     public FieldRef fieldRef() {
-        return new FieldRef(fieldIndex, fieldName, type);
+        return ((FieldTransform) transform).fieldRef();
     }
 
     public List literals() {
@@ -86,20 +85,15 @@ public List literals() {
     }
 
     public LeafPredicate copyWithNewIndex(int fieldIndex) {
-        return new LeafPredicate(function, type, fieldIndex, fieldName, literals);
-    }
-
-    @Override
-    public boolean test(InternalRow row) {
-        return function.test(type, get(row, fieldIndex, type), literals);
+        return new LeafPredicate(function, type(), fieldIndex, fieldName(), literals);
     }
 
     @Override
     public boolean test(
             long rowCount, InternalRow minValues, InternalRow maxValues, InternalArray nullCounts) {
-        Object min = get(minValues, fieldIndex, type);
-        Object max = get(maxValues, fieldIndex, type);
-        Long nullCount = nullCounts.isNullAt(fieldIndex) ? null : nullCounts.getLong(fieldIndex);
+        Object min = get(minValues, index(), type());
+        Object max = get(maxValues, index(), type());
+        Long nullCount = nullCounts.isNullAt(index()) ? null : nullCounts.getLong(index());
         if (nullCount == null || rowCount != nullCount) {
             // not all null
             // min or max is null
@@ -108,13 +102,13 @@ public boolean test(
                 return true;
             }
         }
-        return function.test(type, rowCount, min, max, nullCount, literals);
+        return function.test(type(), rowCount, min, max, nullCount, literals);
     }
 
     @Override
     public Optional negate() {
         return function.negate()
-                .map(negate -> new LeafPredicate(negate, type, fieldIndex, fieldName, literals));
+                .map(negate -> new LeafPredicate(negate, type(), index(), fieldName(), literals));
     }
 
     @Override
@@ -122,27 +116,6 @@ public T visit(PredicateVisitor visitor) {
         return visitor.visit(this);
     }
 
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
-        }
-        LeafPredicate that = (LeafPredicate) o;
-        return fieldIndex == that.fieldIndex
-                && Objects.equals(fieldName, that.fieldName)
-                && Objects.equals(function, that.function)
-                && Objects.equals(type, that.type)
-                && Objects.equals(literals, that.literals);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(function, type, fieldIndex, fieldName, literals);
-    }
-
     @Override
     public String toString() {
         String literalsStr;
@@ -154,13 +127,13 @@ public String toString() {
             literalsStr = literals.toString();
         }
         return literalsStr.isEmpty()
-                ? function + "(" + fieldName + ")"
-                : function + "(" + fieldName + ", " + literalsStr + ")";
function + "(" + fieldName() + ")" + : function + "(" + fieldName() + ", " + literalsStr + ")"; } private ListSerializer objectsSerializer() { return new ListSerializer<>( - NullableSerializer.wrapIfNullIsNotSupported(InternalSerializers.create(type))); + NullableSerializer.wrapIfNullIsNotSupported(InternalSerializers.create(type()))); } private void writeObject(ObjectOutputStream out) throws IOException { diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java b/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java index 326b82aecbd9..26de3686bb4c 100644 --- a/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java @@ -77,57 +77,109 @@ public Predicate equal(int idx, Object literal) { return leaf(Equal.INSTANCE, idx, literal); } + public Predicate equal(Transform transform, Object literal) { + return leaf(Equal.INSTANCE, transform, literal); + } + public Predicate notEqual(int idx, Object literal) { return leaf(NotEqual.INSTANCE, idx, literal); } + public Predicate notEqual(Transform transform, Object literal) { + return leaf(NotEqual.INSTANCE, transform, literal); + } + public Predicate lessThan(int idx, Object literal) { return leaf(LessThan.INSTANCE, idx, literal); } + public Predicate lessThan(Transform transform, Object literal) { + return leaf(LessThan.INSTANCE, transform, literal); + } + public Predicate lessOrEqual(int idx, Object literal) { return leaf(LessOrEqual.INSTANCE, idx, literal); } + public Predicate lessOrEqual(Transform transform, Object literal) { + return leaf(LessOrEqual.INSTANCE, transform, literal); + } + public Predicate greaterThan(int idx, Object literal) { return leaf(GreaterThan.INSTANCE, idx, literal); } + public Predicate greaterThan(Transform transform, Object literal) { + return leaf(GreaterThan.INSTANCE, transform, literal); + } + public Predicate greaterOrEqual(int idx, Object literal) { return leaf(GreaterOrEqual.INSTANCE, idx, literal); } + public Predicate greaterOrEqual(Transform transform, Object literal) { + return leaf(GreaterOrEqual.INSTANCE, transform, literal); + } + public Predicate isNull(int idx) { return leaf(IsNull.INSTANCE, idx); } + public Predicate isNull(Transform transform) { + return leaf(IsNull.INSTANCE, transform); + } + public Predicate isNotNull(int idx) { return leaf(IsNotNull.INSTANCE, idx); } + public Predicate isNotNull(Transform transform) { + return leaf(IsNotNull.INSTANCE, transform); + } + public Predicate startsWith(int idx, Object patternLiteral) { return leaf(StartsWith.INSTANCE, idx, patternLiteral); } + public Predicate startsWith(Transform transform, Object patternLiteral) { + return leaf(StartsWith.INSTANCE, transform, patternLiteral); + } + public Predicate endsWith(int idx, Object patternLiteral) { return leaf(EndsWith.INSTANCE, idx, patternLiteral); } + public Predicate endsWith(Transform transform, Object patternLiteral) { + return leaf(EndsWith.INSTANCE, transform, patternLiteral); + } + public Predicate contains(int idx, Object patternLiteral) { return leaf(Contains.INSTANCE, idx, patternLiteral); } - public Predicate leaf(NullFalseLeafBinaryFunction function, int idx, Object literal) { + public Predicate contains(Transform transform, Object patternLiteral) { + return leaf(Contains.INSTANCE, transform, patternLiteral); + } + + private Predicate leaf(NullFalseLeafBinaryFunction function, int idx, Object literal) { DataField field = 
diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java b/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java
index 326b82aecbd9..26de3686bb4c 100644
--- a/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/PredicateBuilder.java
@@ -77,57 +77,109 @@ public Predicate equal(int idx, Object literal) {
         return leaf(Equal.INSTANCE, idx, literal);
     }
 
+    public Predicate equal(Transform transform, Object literal) {
+        return leaf(Equal.INSTANCE, transform, literal);
+    }
+
     public Predicate notEqual(int idx, Object literal) {
         return leaf(NotEqual.INSTANCE, idx, literal);
     }
 
+    public Predicate notEqual(Transform transform, Object literal) {
+        return leaf(NotEqual.INSTANCE, transform, literal);
+    }
+
     public Predicate lessThan(int idx, Object literal) {
         return leaf(LessThan.INSTANCE, idx, literal);
     }
 
+    public Predicate lessThan(Transform transform, Object literal) {
+        return leaf(LessThan.INSTANCE, transform, literal);
+    }
+
     public Predicate lessOrEqual(int idx, Object literal) {
         return leaf(LessOrEqual.INSTANCE, idx, literal);
     }
 
+    public Predicate lessOrEqual(Transform transform, Object literal) {
+        return leaf(LessOrEqual.INSTANCE, transform, literal);
+    }
+
     public Predicate greaterThan(int idx, Object literal) {
         return leaf(GreaterThan.INSTANCE, idx, literal);
     }
 
+    public Predicate greaterThan(Transform transform, Object literal) {
+        return leaf(GreaterThan.INSTANCE, transform, literal);
+    }
+
     public Predicate greaterOrEqual(int idx, Object literal) {
         return leaf(GreaterOrEqual.INSTANCE, idx, literal);
     }
 
+    public Predicate greaterOrEqual(Transform transform, Object literal) {
+        return leaf(GreaterOrEqual.INSTANCE, transform, literal);
+    }
+
     public Predicate isNull(int idx) {
         return leaf(IsNull.INSTANCE, idx);
     }
 
+    public Predicate isNull(Transform transform) {
+        return leaf(IsNull.INSTANCE, transform);
+    }
+
     public Predicate isNotNull(int idx) {
         return leaf(IsNotNull.INSTANCE, idx);
     }
 
+    public Predicate isNotNull(Transform transform) {
+        return leaf(IsNotNull.INSTANCE, transform);
+    }
+
     public Predicate startsWith(int idx, Object patternLiteral) {
         return leaf(StartsWith.INSTANCE, idx, patternLiteral);
     }
 
+    public Predicate startsWith(Transform transform, Object patternLiteral) {
+        return leaf(StartsWith.INSTANCE, transform, patternLiteral);
+    }
+
     public Predicate endsWith(int idx, Object patternLiteral) {
         return leaf(EndsWith.INSTANCE, idx, patternLiteral);
     }
 
+    public Predicate endsWith(Transform transform, Object patternLiteral) {
+        return leaf(EndsWith.INSTANCE, transform, patternLiteral);
+    }
+
     public Predicate contains(int idx, Object patternLiteral) {
         return leaf(Contains.INSTANCE, idx, patternLiteral);
     }
 
-    public Predicate leaf(NullFalseLeafBinaryFunction function, int idx, Object literal) {
+    public Predicate contains(Transform transform, Object patternLiteral) {
+        return leaf(Contains.INSTANCE, transform, patternLiteral);
+    }
+
+    private Predicate leaf(NullFalseLeafBinaryFunction function, int idx, Object literal) {
         DataField field = rowType.getFields().get(idx);
         return new LeafPredicate(function, field.type(), idx, field.name(), singletonList(literal));
     }
 
-    public Predicate leaf(LeafUnaryFunction function, int idx) {
+    private Predicate leaf(LeafFunction function, Transform transform, Object literal) {
+        return TransformPredicate.of(transform, function, singletonList(literal));
+    }
+
+    private Predicate leaf(LeafUnaryFunction function, int idx) {
         DataField field = rowType.getFields().get(idx);
         return new LeafPredicate(
                 function, field.type(), idx, field.name(), Collections.emptyList());
     }
 
+    private Predicate leaf(LeafFunction function, Transform transform) {
+        return TransformPredicate.of(transform, function, Collections.emptyList());
+    }
+
     public Predicate in(int idx, List literals) {
         // In the IN predicate, 20 literals are critical for performance.
         // If there are more than 20 literals, the performance will decrease.
@@ -143,6 +195,20 @@ public Predicate in(int idx, List literals) {
         return or(equals);
     }
 
+    public Predicate in(Transform transform, List literals) {
+        // In the IN predicate, 20 literals are critical for performance.
+        // If there are more than 20 literals, the performance will decrease.
+        if (literals.size() > 20) {
+            return TransformPredicate.of(transform, In.INSTANCE, literals);
+        }
+
+        List equals = new ArrayList<>(literals.size());
+        for (Object literal : literals) {
+            equals.add(equal(transform, literal));
+        }
+        return or(equals);
+    }
+
     public Predicate notIn(int idx, List literals) {
         return in(idx, literals).negate().get();
     }
@@ -155,6 +221,15 @@ public Predicate between(int idx, Object includedLowerBound, Object includedUppe
                         lessOrEqual(idx, includedUpperBound)));
     }
 
+    public Predicate between(
+            Transform transform, Object includedLowerBound, Object includedUpperBound) {
+        return new CompoundPredicate(
+                And.INSTANCE,
+                Arrays.asList(
+                        greaterOrEqual(transform, includedLowerBound),
+                        lessOrEqual(transform, includedUpperBound)));
+    }
+
     public static Predicate and(Predicate... predicates) {
         return and(Arrays.asList(predicates));
     }
 
@@ -366,20 +441,26 @@ public static Optional transformFieldMapping(
                 }
             }
             return Optional.of(new CompoundPredicate(compoundPredicate.function(), children));
-        } else {
-            LeafPredicate leafPredicate = (LeafPredicate) predicate;
-            int mapped = fieldIdxMapping[leafPredicate.index()];
-            if (mapped >= 0) {
-                return Optional.of(
-                        new LeafPredicate(
-                                leafPredicate.function(),
-                                leafPredicate.type(),
-                                mapped,
-                                leafPredicate.fieldName(),
-                                leafPredicate.literals()));
-            } else {
-                return Optional.empty();
+        } else if (predicate instanceof TransformPredicate) {
+            TransformPredicate transformPredicate = (TransformPredicate) predicate;
+            List inputs = transformPredicate.transform.inputs();
+            List newInputs = new ArrayList<>(inputs.size());
+            for (Object input : inputs) {
+                if (input instanceof FieldRef) {
+                    FieldRef fieldRef = (FieldRef) input;
+                    int mappedIndex = fieldIdxMapping[fieldRef.index()];
+                    if (mappedIndex >= 0) {
+                        newInputs.add(new FieldRef(mappedIndex, fieldRef.name(), fieldRef.type()));
+                    } else {
+                        return Optional.empty();
+                    }
+                } else {
+                    newInputs.add(input);
+                }
             }
+            return Optional.of(transformPredicate.withNewInputs(newInputs));
+        } else {
+            return Optional.empty();
         }
     }
 
@@ -392,8 +473,8 @@ public static boolean containsFields(Predicate predicate, Set fields) {
             }
             return false;
         } else {
-            LeafPredicate leafPredicate = (LeafPredicate) predicate;
-            return fields.contains(leafPredicate.fieldName());
+            TransformPredicate transformPredicate = (TransformPredicate) predicate;
+            return fields.containsAll(transformPredicate.fieldNames());
         }
     }
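The transform-based `in` keeps the same 20-literal heuristic as the field-index version: a short list is expanded into an OR of per-literal equality predicates, while a longer list stays a single `In` predicate over the transform. The sketch below shows the two resulting shapes; the single-column schema (`f0`) and the literal values are made up for illustration.

```java
import org.apache.paimon.data.BinaryString;
import org.apache.paimon.predicate.CompoundPredicate;
import org.apache.paimon.predicate.FieldRef;
import org.apache.paimon.predicate.FieldTransform;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class InThresholdSketch {
    public static void main(String[] args) {
        RowType rowType = RowType.of(DataTypes.STRING()); // single STRING column named "f0"
        PredicateBuilder builder = new PredicateBuilder(rowType);
        FieldTransform f0 = new FieldTransform(new FieldRef(0, "f0", DataTypes.STRING()));

        Predicate expanded = builder.in(f0, literals(3));  // 3 literals: OR of equals
        Predicate single = builder.in(f0, literals(25));   // 25 literals: one In predicate

        System.out.println(expanded instanceof CompoundPredicate); // true: OR of per-literal equals
        System.out.println(single instanceof CompoundPredicate);   // false: single In over the transform
    }

    private static List<Object> literals(int n) {
        // Build n distinct string literals ("v0", "v1", ...) as Paimon BinaryStrings.
        return IntStream.range(0, n)
                .mapToObj(i -> (Object) BinaryString.fromString("v" + i))
                .collect(Collectors.toList());
    }
}
```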
diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/Transform.java b/paimon-common/src/main/java/org/apache/paimon/predicate/Transform.java
index 3ab5c97e072d..1324133c7ec1 100644
--- a/paimon-common/src/main/java/org/apache/paimon/predicate/Transform.java
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/Transform.java
@@ -32,4 +32,6 @@ public interface Transform extends Serializable {
     DataType outputType();
 
     Object transform(InternalRow row);
+
+    Transform withNewInputs(List inputs);
 }
diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/TransformPredicate.java b/paimon-common/src/main/java/org/apache/paimon/predicate/TransformPredicate.java
index 266bb517d936..94a1600359f5 100644
--- a/paimon-common/src/main/java/org/apache/paimon/predicate/TransformPredicate.java
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/TransformPredicate.java
@@ -29,6 +29,7 @@ import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -38,20 +39,43 @@ public class TransformPredicate implements Predicate {
 
     private static final long serialVersionUID = 1L;
 
-    private final Transform transform;
-    private final LeafFunction function;
-    private transient List literals;
+    protected final Transform transform;
+    protected final LeafFunction function;
+    protected transient List literals;
 
-    public TransformPredicate(Transform transform, LeafFunction function, List literals) {
+    protected TransformPredicate(
+            Transform transform, LeafFunction function, List literals) {
         this.transform = transform;
         this.function = function;
         this.literals = literals;
     }
 
+    public static TransformPredicate of(
+            Transform transform, LeafFunction function, List literals) {
+        if (transform instanceof FieldTransform) {
+            return new LeafPredicate((FieldTransform) transform, function, literals);
+        }
+        return new TransformPredicate(transform, function, literals);
+    }
+
     public Transform transform() {
         return transform;
     }
 
+    public TransformPredicate withNewInputs(List newInputs) {
+        return TransformPredicate.of(transform.withNewInputs(newInputs), function, literals);
+    }
+
+    public List fieldNames() {
+        List names = new ArrayList<>();
+        for (Object input : transform.inputs()) {
+            if (input instanceof FieldRef) {
+                names.add(((FieldRef) input).name());
+            }
+        }
+        return names;
+    }
+
     @Override
     public boolean test(InternalRow row) {
         Object value = transform.transform(row);
diff --git a/paimon-common/src/test/java/org/apache/paimon/predicate/TransformPredicateTest.java b/paimon-common/src/test/java/org/apache/paimon/predicate/TransformPredicateTest.java
index 90d144f8d0fa..c2f6fa32be43 100644
--- a/paimon-common/src/test/java/org/apache/paimon/predicate/TransformPredicateTest.java
+++ b/paimon-common/src/test/java/org/apache/paimon/predicate/TransformPredicateTest.java
@@ -76,6 +76,6 @@ private TransformPredicate create() {
         ConcatTransform transform = new ConcatTransform(inputs);
         List literals = new ArrayList<>();
         literals.add(BinaryString.fromString("ha-he"));
-        return new TransformPredicate(transform, Equal.INSTANCE, literals);
+        return TransformPredicate.of(transform, Equal.INSTANCE, literals);
     }
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBasePushDown.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBasePushDown.scala
index f56e7b686470..3f1020260772 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBasePushDown.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBasePushDown.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.types.RowType
 
 import org.apache.spark.sql.PaimonUtils
 import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate}
-import org.apache.spark.sql.connector.read.{SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{SupportsPushDownLimit, SupportsPushDownV2Filters}
 import org.apache.spark.sql.sources.Filter
 
 import java.util.{List => JList}
@@ -56,7 +56,7 @@ trait PaimonBasePushDown extends SupportsPushDownV2Filters with SupportsPushDown
         pushable.append((predicate, paimonPredicate))
         if (paimonPredicate.visit(visitor)) {
           // We need to filter the stats using filter instead of predicate.
-          reserved.append(PaimonUtils.filterV2ToV1(predicate).get)
+          PaimonUtils.filterV2ToV1(predicate).map(reserved.append(_))
         } else {
           postScan.append(predicate)
         }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index 70fffc4d3708..d4f0b0cfe043 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -27,8 +27,6 @@ import org.apache.paimon.spark.statistics.StatisticsHelper
 import org.apache.paimon.table.{DataTable, InnerTable}
 import org.apache.paimon.table.source.{InnerTableScan, Split}
 import org.apache.paimon.table.source.snapshot.TimeTravelUtil
-import org.apache.paimon.table.system.FilesTable
-import org.apache.paimon.utils.{SnapshotManager, TagManager}
 
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Batch, Scan, Statistics, SupportsReportStatistics}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index d3ff6d1a4099..8d8d1825ba1b 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -191,7 +191,8 @@ case class PaimonScan(
     val converter = SparkV2FilterConverter(table.rowType())
     val partitionKeys = table.partitionKeys().asScala.toSeq
     val partitionFilter = predicates.flatMap {
-      case p if SparkV2FilterConverter.isSupportedRuntimeFilter(p, partitionKeys) =>
+      case p
+          if SparkV2FilterConverter(table.rowType()).isSupportedRuntimeFilter(p, partitionKeys) =>
         converter.convert(p)
       case _ => None
     }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
index 87a8ddf99381..84937e2fdd61 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
@@ -18,21 +18,19 @@
 
 package org.apache.paimon.spark
 
-import org.apache.paimon.data.{BinaryString, Decimal, Timestamp}
-import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
-import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType
-import org.apache.paimon.types.{DataTypeRoot, DecimalType, RowType}
-import org.apache.paimon.types.DataTypeRoot._
+import org.apache.paimon.predicate.{FieldTransform, Predicate, PredicateBuilder, Transform}
+import org.apache.paimon.spark.util.SparkExpressionConverter.{toPaimonLiteral, toPaimonTransform}
+import org.apache.paimon.types.RowType
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.connector.expressions.{Literal, NamedReference}
+import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.Literal
 import org.apache.spark.sql.connector.expressions.filter.{And, Not, Or, Predicate => SparkPredicate}
 
 import scala.collection.JavaConverters._
 
 /** Conversion from [[SparkPredicate]] to [[Predicate]]. */
-case class SparkV2FilterConverter(rowType: RowType) {
+case class SparkV2FilterConverter(rowType: RowType) extends Logging {
 
   import org.apache.paimon.spark.SparkV2FilterConverter._
 
@@ -50,85 +48,78 @@ case class SparkV2FilterConverter(rowType: RowType) {
   private def convert(sparkPredicate: SparkPredicate): Predicate = {
     sparkPredicate.name() match {
       case EQUAL_TO =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
            // TODO deal with isNaN
-            val index = fieldIndex(fieldName)
-            builder.equal(index, convertLiteral(index, literal))
+            builder.equal(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case EQUAL_NULL_SAFE =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
             if (literal == null) {
-              builder.isNull(index)
+              builder.isNull(transform)
             } else {
-              builder.equal(index, convertLiteral(index, literal))
+              builder.equal(transform, literal)
             }
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case GREATER_THAN =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.greaterThan(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.greaterThan(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case GREATER_THAN_OR_EQUAL =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.greaterOrEqual(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.greaterOrEqual(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case LESS_THAN =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.lessThan(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.lessThan(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case LESS_THAN_OR_EQUAL =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.lessOrEqual(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.lessOrEqual(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case IN =>
-        MultiPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literals)) =>
-            val index = fieldIndex(fieldName)
-            builder.in(index, literals.map(convertLiteral(index, _)).toList.asJava)
+        sparkPredicate match {
+          case MultiPredicate(transform, literals) =>
+            builder.in(transform, literals.toList.asJava)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case IS_NULL =>
-        UnaryPredicate.unapply(sparkPredicate) match {
-          case Some(fieldName) =>
-            builder.isNull(fieldIndex(fieldName))
+        sparkPredicate match {
+          case UnaryPredicate(transform) =>
+            builder.isNull(transform)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case IS_NOT_NULL =>
-        UnaryPredicate.unapply(sparkPredicate) match {
-          case Some(fieldName) =>
-            builder.isNotNull(fieldIndex(fieldName))
+        sparkPredicate match {
+          case UnaryPredicate(transform) =>
+            builder.isNotNull(transform)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
@@ -151,28 +142,25 @@ case class SparkV2FilterConverter(rowType: RowType) {
        }
 
       case STRING_START_WITH =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.startsWith(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.startsWith(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case STRING_END_WITH =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.endsWith(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.endsWith(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
       case STRING_CONTAINS =>
-        BinaryPredicate.unapply(sparkPredicate) match {
-          case Some((fieldName, literal)) =>
-            val index = fieldIndex(fieldName)
-            builder.contains(index, convertLiteral(index, literal))
+        sparkPredicate match {
+          case BinaryPredicate(transform, literal) =>
+            builder.contains(transform, literal)
          case _ =>
            throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
        }
 
@@ -182,106 +170,55 @@ case class SparkV2FilterConverter(rowType: RowType) {
     }
   }
 
-  private def fieldIndex(fieldName: String): Int = {
-    val index = rowType.getFieldIndex(fieldName)
-    // TODO: support nested field
-    if (index == -1) {
-      throw new UnsupportedOperationException(s"Nested field '$fieldName' is unsupported.")
-    }
-    index
-  }
-
-  private def convertLiteral(index: Int, value: Any): AnyRef = {
-    if (value == null) {
-      return null
-    }
-
-    val dataType = rowType.getTypeAt(index)
-    dataType.getTypeRoot match {
-      case BOOLEAN | BIGINT | DOUBLE | TINYINT | SMALLINT | INTEGER | FLOAT | DATE =>
-        value.asInstanceOf[AnyRef]
-      case DataTypeRoot.VARCHAR =>
-        BinaryString.fromString(value.toString)
-      case DataTypeRoot.DECIMAL =>
-        val decimalType = dataType.asInstanceOf[DecimalType]
-        val precision = decimalType.getPrecision
-        val scale = decimalType.getScale
-        Decimal.fromBigDecimal(
-          value.asInstanceOf[org.apache.spark.sql.types.Decimal].toJavaBigDecimal,
-          precision,
-          scale)
-      case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
-        Timestamp.fromMicros(value.asInstanceOf[Long])
-      case DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE =>
-        if (treatPaimonTimestampTypeAsSparkTimestampType()) {
-          Timestamp.fromSQLTimestamp(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]))
-        } else {
-          Timestamp.fromMicros(value.asInstanceOf[Long])
-        }
-      case _ =>
-        throw new UnsupportedOperationException(
-          s"Convert value: $value to datatype: $dataType is unsupported.")
-    }
-  }
-}
-
-object SparkV2FilterConverter extends Logging {
-
-  private val EQUAL_TO = "="
-  private val EQUAL_NULL_SAFE = "<=>"
-  private val GREATER_THAN = ">"
-  private val GREATER_THAN_OR_EQUAL = ">="
-  private val LESS_THAN = "<"
-  private val LESS_THAN_OR_EQUAL = "<="
- private val IN = "IN" - private val IS_NULL = "IS_NULL" - private val IS_NOT_NULL = "IS_NOT_NULL" - private val AND = "AND" - private val OR = "OR" - private val NOT = "NOT" - private val STRING_START_WITH = "STARTS_WITH" - private val STRING_END_WITH = "ENDS_WITH" - private val STRING_CONTAINS = "CONTAINS" - private object UnaryPredicate { - def unapply(sparkPredicate: SparkPredicate): Option[String] = { + def unapply(sparkPredicate: SparkPredicate): Option[Transform] = { sparkPredicate.children() match { - case Array(n: NamedReference) => Some(toFieldName(n)) + case Array(e: Expression) => toPaimonTransform(e, rowType) case _ => None } } } private object BinaryPredicate { - def unapply(sparkPredicate: SparkPredicate): Option[(String, Any)] = { + def unapply(sparkPredicate: SparkPredicate): Option[(Transform, Object)] = { sparkPredicate.children() match { - case Array(l: NamedReference, r: Literal[_]) => Some((toFieldName(l), r.value)) - case Array(l: Literal[_], r: NamedReference) => Some((toFieldName(r), l.value)) + case Array(e: Expression, r: Literal[_]) => + toPaimonTransform(e, rowType) match { + case Some(transform) => Some(transform, toPaimonLiteral(r)) + case _ => None + } case _ => None } } } private object MultiPredicate { - def unapply(sparkPredicate: SparkPredicate): Option[(String, Array[Any])] = { + def unapply(sparkPredicate: SparkPredicate): Option[(Transform, Seq[Object])] = { sparkPredicate.children() match { - case Array(first: NamedReference, rest @ _*) + case Array(e: Expression, rest @ _*) if rest.nonEmpty && rest.forall(_.isInstanceOf[Literal[_]]) => - Some(toFieldName(first), rest.map(_.asInstanceOf[Literal[_]].value).toArray) + val literals = rest.map(_.asInstanceOf[Literal[_]]) + if (literals.forall(_.dataType() == literals.head.dataType())) { + toPaimonTransform(e, rowType) match { + case Some(transform) => Some(transform, literals.map(toPaimonLiteral)) + case _ => None + } + } else { + None + } case _ => None } } } - private def toFieldName(ref: NamedReference): String = ref.fieldNames().mkString(".") - def isSupportedRuntimeFilter( sparkPredicate: SparkPredicate, partitionKeys: Seq[String]): Boolean = { sparkPredicate.name() match { case IN => - MultiPredicate.unapply(sparkPredicate) match { - case Some((fieldName, _)) => partitionKeys.contains(fieldName) + sparkPredicate match { + case MultiPredicate(transform: FieldTransform, _) => + partitionKeys.contains(transform.fieldRef().name()) case _ => logWarning(s"Convert $sparkPredicate is unsupported.") false @@ -290,3 +227,23 @@ object SparkV2FilterConverter extends Logging { } } } + +object SparkV2FilterConverter extends Logging { + + private val EQUAL_TO = "=" + private val EQUAL_NULL_SAFE = "<=>" + private val GREATER_THAN = ">" + private val GREATER_THAN_OR_EQUAL = ">=" + private val LESS_THAN = "<" + private val LESS_THAN_OR_EQUAL = "<=" + private val IN = "IN" + private val IS_NULL = "IS_NULL" + private val IS_NOT_NULL = "IS_NOT_NULL" + private val AND = "AND" + private val OR = "OR" + private val NOT = "NOT" + private val STRING_START_WITH = "STARTS_WITH" + private val STRING_END_WITH = "ENDS_WITH" + private val STRING_CONTAINS = "CONTAINS" + +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala index 8bc976b8669d..2bed4af87376 100644 --- 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
index 8bc976b8669d..2bed4af87376 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
@@ -26,7 +26,6 @@ import org.apache.paimon.types.RowType
 
 import org.apache.spark.sql.{Column, SparkSession}
 import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilterV2}
-import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Cast, Expression, GetStructField, Literal, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala
new file mode 100644
index 000000000000..b32436cabab5
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.util
+
+import org.apache.paimon.data.{BinaryString, Decimal, Timestamp}
+import org.apache.paimon.predicate.{ConcatTransform, FieldRef, FieldTransform, Transform}
+import org.apache.paimon.spark.SparkTypeUtils
+import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType
+import org.apache.paimon.types.{DecimalType, RowType}
+import org.apache.paimon.types.DataTypeRoot._
+
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference}
+
+import scala.collection.JavaConverters._
+
+object SparkExpressionConverter {
+
+  // Supported transform names
+  private val CONCAT = "CONCAT"
+
+  /** Convert Spark [[Expression]] to Paimon [[Transform]], return None if not supported. */
+  def toPaimonTransform(exp: Expression, rowType: RowType): Option[Transform] = {
+    exp match {
+      case n: NamedReference => Some(new FieldTransform(toPaimonFieldRef(n, rowType)))
+      case s: GeneralScalarExpression =>
+        s.name() match {
+          case CONCAT =>
+            val inputs = exp.children().map {
+              case n: NamedReference => toPaimonFieldRef(n, rowType)
+              case l: Literal[_] => toPaimonLiteral(l)
+              case _ => return None
+            }
+            Some(new ConcatTransform(inputs.toList.asJava))
+          case _ => None
+        }
+      case _ => None
+    }
+  }
+
+  /** Convert Spark [[Literal]] to Paimon literal. */
+  def toPaimonLiteral(literal: Literal[_]): Object = {
+    if (literal == null) {
+      return null
+    }
+
+    if (literal.children().nonEmpty) {
+      throw new UnsupportedOperationException(s"Convert value: $literal is unsupported.")
+    }
+
+    val dataType = SparkTypeUtils.toPaimonType(literal.dataType())
+    val value = literal.value()
+    dataType.getTypeRoot match {
+      case BOOLEAN | BIGINT | DOUBLE | TINYINT | SMALLINT | INTEGER | FLOAT | DATE =>
+        value.asInstanceOf[AnyRef]
+      case VARCHAR =>
+        BinaryString.fromString(value.toString)
+      case DECIMAL =>
+        val decimalType = dataType.asInstanceOf[DecimalType]
+        val precision = decimalType.getPrecision
+        val scale = decimalType.getScale
+        Decimal.fromBigDecimal(
+          value.asInstanceOf[org.apache.spark.sql.types.Decimal].toJavaBigDecimal,
+          precision,
+          scale)
+      case TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
+        Timestamp.fromMicros(value.asInstanceOf[Long])
+      case TIMESTAMP_WITHOUT_TIME_ZONE =>
+        if (treatPaimonTimestampTypeAsSparkTimestampType()) {
+          Timestamp.fromSQLTimestamp(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]))
+        } else {
+          Timestamp.fromMicros(value.asInstanceOf[Long])
+        }
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Convert value: $value to datatype: $dataType is unsupported.")
+    }
+  }
+
+  private def toPaimonFieldRef(ref: NamedReference, rowType: RowType): FieldRef = {
+    val fieldName = toFieldName(ref)
+    val f = rowType.getField(fieldName)
+    // Note: this must be the field index, not the field id.
+    val index = rowType.getFieldIndex(fieldName)
+    if (index == -1) {
+      throw new UnsupportedOperationException(s"Nested field '$fieldName' is unsupported.")
+    }
+    new FieldRef(index, f.name(), f.`type`())
+  }
+
+  private def toFieldName(ref: NamedReference): String = ref.fieldNames().mkString(".")
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/connector/catalog/PaimonCatalogUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/connector/catalog/PaimonCatalogUtils.scala
index f330fed3f38d..94cbe0d00c79 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/connector/catalog/PaimonCatalogUtils.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/connector/catalog/PaimonCatalogUtils.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.connector.catalog
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{PaimonSparkSession, SparkSession}
+import org.apache.spark.sql.PaimonSparkSession
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalog
-import org.apache.spark.sql.connector.catalog.CatalogV2Util
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.paimon.ReflectUtils
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
index 8759f99f00a2..5a4ff36b6c82 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
@@ -87,6 +87,48 @@ abstract class PaimonPushDownTestBase extends PaimonSparkTestBase {
     checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Row(3, "c", "p2") :: Nil)
   }
 
+  test("Paimon push down: apply CONCAT") {
+    // Spark supports pushing down CONCAT since Spark 3.4.
+    if (gteqSpark3_4) {
+      withTable("t") {
+        sql(
+          """
+            |CREATE TABLE t (id int, value int, year STRING, month STRING, day STRING, hour STRING)
+            |using paimon
+            |PARTITIONED BY (year, month, day, hour)
+            |""".stripMargin)
+
+        sql("""
+              |INSERT INTO t values
+              |(1, 100, '2024', '07', '15', '21'),
+              |(2, 200, '2025', '07', '15', '21'),
+              |(3, 300, '2025', '07', '16', '22'),
+              |(4, 400, '2025', '07', '16', '23'),
+              |(5, 440, '2025', '07', '16', '23'),
+              |(6, 500, '2025', '07', '17', '00'),
+              |(7, 600, '2025', '07', '17', '02')
+              |""".stripMargin)
+
+        val q =
+          """
+            |SELECT * FROM t
+            |WHERE CONCAT(year,'-',month,'-',day,'-',hour) BETWEEN '2025-07-16-21' AND '2025-07-17-01'
+            |ORDER BY id
+            |""".stripMargin
+        assert(!checkFilterExists(q))
+
+        checkAnswer(
+          spark.sql(q),
+          Seq(
+            Row(3, 300, "2025", "07", "16", "22"),
+            Row(4, 400, "2025", "07", "16", "23"),
+            Row(5, 440, "2025", "07", "16", "23"),
+            Row(6, 500, "2025", "07", "17", "00"))
+        )
+      }
+    }
+  }
+
   test("Paimon pushDown: limit for append-only tables with deletion vector") {
     withTable("dv_test") {
       spark.sql(