diff --git a/api/src/main/java/org/apache/iceberg/TableScan.java b/api/src/main/java/org/apache/iceberg/TableScan.java index 5d2a1269d633..3c3c44b369de 100644 --- a/api/src/main/java/org/apache/iceberg/TableScan.java +++ b/api/src/main/java/org/apache/iceberg/TableScan.java @@ -101,4 +101,13 @@ default TableScan appendsAfter(long fromSnapshotId) { * @return the Snapshot this scan will use */ Snapshot snapshot(); + + /** + * Create a new {@link TableScan} from this scan's configuration that will have column stats + * + * @return a new scan based on this with column stats + */ + default TableScan withColStats() { + throw new UnsupportedOperationException("scan with colStats is not supported"); + } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/AggregateEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/AggregateEvaluator.java new file mode 100644 index 000000000000..7c03796c249b --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/AggregateEvaluator.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.expressions; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.expressions.BoundAggregate.Aggregator; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; + +/** + * A class for evaluating aggregates. It evaluates each of the aggregates and updates the aggregated + * value. The final aggregated result can be returned by {@link #result()}. + */ +public class AggregateEvaluator { + + public static AggregateEvaluator create(Schema schema, List aggregates) { + return create(schema.asStruct(), aggregates); + } + + public static AggregateEvaluator create(List> aggregates) { + return new AggregateEvaluator(aggregates); + } + + private static AggregateEvaluator create(Types.StructType struct, List aggregates) { + List> boundAggregates = + aggregates.stream() + .map(expr -> Binder.bind(struct, expr)) + .map(bound -> (BoundAggregate) bound) + .collect(Collectors.toList()); + + return new AggregateEvaluator(boundAggregates); + } + + private final List> aggregators; + private final Types.StructType resultType; + private final List> aggregates; + + private AggregateEvaluator(List> aggregates) { + ImmutableList.Builder> aggregatorsBuilder = ImmutableList.builder(); + List resultFields = Lists.newArrayList(); + + for (int pos = 0; pos < aggregates.size(); pos += 1) { + BoundAggregate aggregate = aggregates.get(pos); + aggregatorsBuilder.add(aggregate.newAggregator()); + resultFields.add(Types.NestedField.optional(pos, aggregate.describe(), aggregate.type())); + } + + this.aggregators = aggregatorsBuilder.build(); + this.resultType = Types.StructType.of(resultFields); + this.aggregates = aggregates; + } + + public void update(StructLike struct) { + for (Aggregator 
aggregator : aggregators) { + aggregator.update(struct); + } + } + + public void update(DataFile file) { + for (Aggregator aggregator : aggregators) { + aggregator.update(file); + } + } + + public Types.StructType resultType() { + return resultType; + } + + public boolean allAggregatorsValid() { + return aggregators.stream().allMatch(BoundAggregate.Aggregator::isValid); + } + + public StructLike result() { + Object[] results = + aggregators.stream().map(BoundAggregate.Aggregator::result).toArray(Object[]::new); + return new ArrayStructLike(results); + } + + public List> aggregates() { + return aggregates; + } + + private static class ArrayStructLike implements StructLike { + private final Object[] values; + + private ArrayStructLike(Object[] values) { + this.values = values; + } + + public int size() { + return values.length; + } + + @Override + public T get(int pos, Class javaClass) { + return javaClass.cast(values[pos]); + } + + @Override + public void set(int pos, T value) { + values[pos] = value; + } + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java index 650271b3b78a..f8db6eac2022 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java +++ b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java @@ -18,18 +18,37 @@ */ package org.apache.iceberg.expressions; +import java.util.Map; +import org.apache.iceberg.DataFile; import org.apache.iceberg.StructLike; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; public class BoundAggregate extends Aggregate> implements Bound { + protected BoundAggregate(Operation op, BoundTerm term) { super(op, term); } @Override public C eval(StructLike struct) { - throw new UnsupportedOperationException(this.getClass().getName() + " does not implement eval"); + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement 
eval(StructLike)"); + } + + C eval(DataFile file) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement eval(DataFile)"); + } + + boolean hasValue(DataFile file) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement hasValue(DataFile)"); + } + + Aggregator newAggregator() { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement newAggregator()"); } @Override @@ -44,4 +63,105 @@ public Type type() { return term().type(); } } + + public String columnName() { + if (op() == Operation.COUNT_STAR) { + return "*"; + } else { + return ref().name(); + } + } + + public String describe() { + switch (op()) { + case COUNT_STAR: + return "count(*)"; + case COUNT: + return "count(" + ExpressionUtil.describe(term()) + ")"; + case MAX: + return "max(" + ExpressionUtil.describe(term()) + ")"; + case MIN: + return "min(" + ExpressionUtil.describe(term()) + ")"; + default: + throw new UnsupportedOperationException("Unsupported aggregate type: " + op()); + } + } + + V safeGet(Map map, int key) { + return safeGet(map, key, null); + } + + V safeGet(Map map, int key, V defaultValue) { + if (map != null) { + return map.getOrDefault(key, defaultValue); + } + + return null; + } + + interface Aggregator { + void update(StructLike struct); + + void update(DataFile file); + + boolean hasValue(DataFile file); + + R result(); + + boolean isValid(); + } + + abstract static class NullSafeAggregator implements Aggregator { + private final BoundAggregate aggregate; + private boolean isValid = true; + + NullSafeAggregator(BoundAggregate aggregate) { + this.aggregate = aggregate; + } + + protected abstract void update(R value); + + protected abstract R current(); + + @Override + public void update(StructLike struct) { + R value = aggregate.eval(struct); + if (value != null) { + update(value); + } + } + + @Override + public boolean hasValue(DataFile file) { + return 
aggregate.hasValue(file); + } + + @Override + public void update(DataFile file) { + if (isValid) { + if (hasValue(file)) { + R value = aggregate.eval(file); + if (value != null) { + update(value); + } + } else { + this.isValid = false; + } + } + } + + @Override + public R result() { + if (!isValid) { + return null; + } + + return current(); + } + + @Override + public boolean isValid() { + return this.isValid; + } + } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/CountAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/CountAggregate.java new file mode 100644 index 000000000000..5ad9c18e8d07 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/CountAggregate.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.expressions; + +import org.apache.iceberg.DataFile; +import org.apache.iceberg.StructLike; + +public class CountAggregate extends BoundAggregate { + protected CountAggregate(Operation op, BoundTerm term) { + super(op, term); + } + + @Override + public Long eval(StructLike struct) { + return countFor(struct); + } + + @Override + public Long eval(DataFile file) { + return countFor(file); + } + + protected Long countFor(StructLike row) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement countFor(StructLike)"); + } + + protected Long countFor(DataFile file) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement countFor(DataFile)"); + } + + @Override + public Aggregator newAggregator() { + return new CountAggregator<>(this); + } + + private static class CountAggregator extends NullSafeAggregator { + private Long count = 0L; + + CountAggregator(BoundAggregate aggregate) { + super(aggregate); + } + + @Override + protected void update(Long value) { + count += value; + } + + @Override + protected Long current() { + return count; + } + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/CountNonNull.java b/api/src/main/java/org/apache/iceberg/expressions/CountNonNull.java new file mode 100644 index 000000000000..10afd72e2e36 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/CountNonNull.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.expressions; + +import org.apache.iceberg.DataFile; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.types.Types; + +public class CountNonNull extends CountAggregate { + private final int fieldId; + private final Types.NestedField field; + + protected CountNonNull(BoundTerm term) { + super(Operation.COUNT, term); + this.field = term.ref().field(); + this.fieldId = field.fieldId(); + } + + @Override + protected Long countFor(StructLike row) { + return term().eval(row) != null ? 1L : 0L; + } + + @Override + protected boolean hasValue(DataFile file) { + return file.valueCounts().containsKey(fieldId) && file.nullValueCounts().containsKey(fieldId); + } + + @Override + protected Long countFor(DataFile file) { + return safeSubtract( + safeGet(file.valueCounts(), fieldId), safeGet(file.nullValueCounts(), fieldId, 0L)); + } + + private Long safeSubtract(Long left, Long right) { + if (left != null && right != null) { + return left - right; + } + + return null; + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/CountStar.java b/api/src/main/java/org/apache/iceberg/expressions/CountStar.java new file mode 100644 index 000000000000..01d29c11b594 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/CountStar.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.expressions; + +import org.apache.iceberg.DataFile; +import org.apache.iceberg.StructLike; + +public class CountStar extends CountAggregate { + protected CountStar(BoundTerm term) { + super(Operation.COUNT_STAR, term); + } + + @Override + protected Long countFor(StructLike row) { + return 1L; + } + + @Override + protected boolean hasValue(DataFile file) { + return file.recordCount() >= 0; + } + + @Override + protected Long countFor(DataFile file) { + long count = file.recordCount(); + if (count < 0) { + return null; + } + + return count; + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java b/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java index c910a77640ff..bff061968fb5 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java +++ b/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java @@ -156,6 +156,26 @@ public static boolean selectsPartitions( caseSensitive); } + public static String describe(Term term) { + if (term instanceof UnboundTransform) { + return ((UnboundTransform) term).transform() + + "(" + + describe(((UnboundTransform) term).ref()) + + ")"; + } else if (term instanceof BoundTransform) { + return ((BoundTransform) term).transform() + + "(" + + describe(((BoundTransform) term).ref()) + + ")"; + } else if (term instanceof NamedReference) { + 
return ((NamedReference) term).name(); + } else if (term instanceof BoundReference) { + return ((BoundReference) term).name(); + } else { + throw new UnsupportedOperationException("Unsupported term: " + term); + } + } + private static class ExpressionSanitizer extends ExpressionVisitors.ExpressionVisitor { private final long now; @@ -271,19 +291,9 @@ public String predicate(BoundPredicate pred) { throw new UnsupportedOperationException("Cannot sanitize bound predicate: " + pred); } - public String termToString(UnboundTerm term) { - if (term instanceof UnboundTransform) { - return ((UnboundTransform) term).transform() + "(" + termToString(term.ref()) + ")"; - } else if (term instanceof NamedReference) { - return ((NamedReference) term).name(); - } else { - throw new UnsupportedOperationException("Unsupported term: " + term); - } - } - @Override public String predicate(UnboundPredicate pred) { - String term = termToString(pred.term()); + String term = describe(pred.term()); switch (pred.op()) { case IS_NULL: return term + " IS NULL"; diff --git a/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java new file mode 100644 index 000000000000..754da9046f5b --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.expressions; + +import java.util.Comparator; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Type.PrimitiveType; +import org.apache.iceberg.types.Types; + +public class MaxAggregate extends ValueAggregate { + private final int fieldId; + private final PrimitiveType type; + private final Comparator comparator; + + protected MaxAggregate(BoundTerm term) { + super(Operation.MAX, term); + Types.NestedField field = term.ref().field(); + this.fieldId = field.fieldId(); + this.type = field.type().asPrimitiveType(); + this.comparator = Comparators.forType(type); + } + + @Override + protected boolean hasValue(DataFile file) { + boolean hasBound = file.upperBounds().containsKey(fieldId); + Long valueCount = safeGet(file.valueCounts(), fieldId); + Long nullCount = safeGet(file.nullValueCounts(), fieldId); + boolean boundAllNull = + valueCount != null + && valueCount > 0 + && nullCount != null + && nullCount.longValue() == valueCount.longValue(); + return hasBound || boundAllNull; + } + + @Override + protected Object evaluateRef(DataFile file) { + return Conversions.fromByteBuffer(type, safeGet(file.upperBounds(), fieldId)); + } + + @Override + public Aggregator newAggregator() { + return new MaxAggregator<>(this, comparator); + } + + private static class MaxAggregator extends NullSafeAggregator { + private final Comparator comparator; + private T max = null; + + MaxAggregator(MaxAggregate aggregate, Comparator 
comparator) { + super(aggregate); + this.comparator = comparator; + } + + @Override + protected void update(T value) { + if (max == null || comparator.compare(value, max) > 0) { + this.max = value; + } + } + + @Override + protected T current() { + return max; + } + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java new file mode 100644 index 000000000000..a6bcea4145c3 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.expressions; + +import java.util.Comparator; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Type.PrimitiveType; +import org.apache.iceberg.types.Types; + +public class MinAggregate extends ValueAggregate { + private final int fieldId; + private final PrimitiveType type; + private final Comparator comparator; + + protected MinAggregate(BoundTerm term) { + super(Operation.MIN, term); + Types.NestedField field = term.ref().field(); + this.fieldId = field.fieldId(); + this.type = field.type().asPrimitiveType(); + this.comparator = Comparators.forType(type); + } + + @Override + protected boolean hasValue(DataFile file) { + boolean hasBound = file.lowerBounds().containsKey(fieldId); + Long valueCount = safeGet(file.valueCounts(), fieldId); + Long nullCount = safeGet(file.nullValueCounts(), fieldId); + boolean boundAllNull = + valueCount != null + && valueCount > 0 + && nullCount != null + && nullCount.longValue() == valueCount.longValue(); + return hasBound || boundAllNull; + } + + @Override + protected Object evaluateRef(DataFile file) { + return Conversions.fromByteBuffer(type, safeGet(file.lowerBounds(), fieldId)); + } + + @Override + public Aggregator newAggregator() { + return new MinAggregator<>(this, comparator); + } + + private static class MinAggregator extends NullSafeAggregator { + private final Comparator comparator; + private T min = null; + + MinAggregator(MinAggregate aggregate, Comparator comparator) { + super(aggregate); + this.comparator = comparator; + } + + @Override + protected void update(T value) { + if (min == null || comparator.compare(value, min) < 0) { + this.min = value; + } + } + + @Override + protected T current() { + return min; + } + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java 
index 5e4cce06c7e8..65e469a631b1 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java +++ b/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java @@ -46,12 +46,22 @@ public NamedReference ref() { */ @Override public Expression bind(Types.StructType struct, boolean caseSensitive) { - if (op() == Operation.COUNT_STAR) { - return new BoundAggregate<>(op(), null); - } else { - Preconditions.checkArgument(term() != null, "Invalid aggregate term: null"); - BoundTerm bound = term().bind(struct, caseSensitive); - return new BoundAggregate<>(op(), bound); + switch (op()) { + case COUNT_STAR: + return new CountStar<>(null); + case COUNT: + return new CountNonNull<>(boundTerm(struct, caseSensitive)); + case MAX: + return new MaxAggregate<>(boundTerm(struct, caseSensitive)); + case MIN: + return new MinAggregate<>(boundTerm(struct, caseSensitive)); + default: + throw new UnsupportedOperationException("Unsupported aggregate type: " + op()); } } + + private BoundTerm boundTerm(Types.StructType struct, boolean caseSensitive) { + Preconditions.checkArgument(term() != null, "Invalid aggregate term: null"); + return term().bind(struct, caseSensitive); + } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/ValueAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/ValueAggregate.java new file mode 100644 index 000000000000..b2e68136dd46 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/ValueAggregate.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.expressions; + +import org.apache.iceberg.DataFile; +import org.apache.iceberg.StructLike; + +class ValueAggregate extends BoundAggregate { + private final SingleValueStruct valueStruct = new SingleValueStruct(); + + protected ValueAggregate(Operation op, BoundTerm term) { + super(op, term); + } + + @Override + public T eval(StructLike struct) { + return term().eval(struct); + } + + public T eval(DataFile file) { + valueStruct.setValue(evaluateRef(file)); + return term().eval(valueStruct); + } + + protected Object evaluateRef(DataFile file) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement eval(DataFile)"); + } + + /** Used to pass a referenced value through term evaluation. 
*/ + private static class SingleValueStruct implements StructLike { + private Object value; + + private void setValue(Object value) { + this.value = value; + } + + @Override + public int size() { + return 1; + } + + @Override + @SuppressWarnings("unchecked") + public T get(int pos, Class javaClass) { + return (T) value; + } + + @Override + public void set(int pos, T value1) { + throw new UnsupportedOperationException("Cannot update a read-only struct"); + } + } +} diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestAggregateEvaluator.java b/api/src/test/java/org/apache/iceberg/expressions/TestAggregateEvaluator.java new file mode 100644 index 000000000000..bd65e041f9df --- /dev/null +++ b/api/src/test/java/org/apache/iceberg/expressions/TestAggregateEvaluator.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.expressions; + +import static org.apache.iceberg.types.Conversions.toByteBuffer; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.List; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.TestHelpers.Row; +import org.apache.iceberg.TestHelpers.TestDataFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types.IntegerType; +import org.apache.iceberg.types.Types.StringType; +import org.junit.Assert; +import org.junit.Test; + +public class TestAggregateEvaluator { + private static final Schema SCHEMA = + new Schema( + required(1, "id", IntegerType.get()), + optional(2, "no_stats", IntegerType.get()), + optional(3, "all_nulls", StringType.get()), + optional(4, "some_nulls", StringType.get())); + + private static final DataFile FILE = + new TestDataFile( + "file.avro", + Row.of(), + 50, + // any value counts, including nulls + ImmutableMap.of(1, 50L, 3, 50L, 4, 50L), + // null value counts + ImmutableMap.of(1, 10L, 3, 50L, 4, 10L), + // nan value counts + null, + // lower bounds + ImmutableMap.of(1, toByteBuffer(IntegerType.get(), 33)), + // upper bounds + ImmutableMap.of(1, toByteBuffer(IntegerType.get(), 2345))); + + private static final DataFile MISSING_SOME_NULLS_STATS_1 = + new TestDataFile( + "file_2.avro", + Row.of(), + 20, + // any value counts, including nulls + ImmutableMap.of(1, 20L, 3, 20L), + // null value counts + ImmutableMap.of(1, 0L, 3, 20L), + // nan value counts + null, + // lower bounds + ImmutableMap.of(1, toByteBuffer(IntegerType.get(), 33)), + // upper bounds + ImmutableMap.of(1, toByteBuffer(IntegerType.get(), 100))); + + private static final DataFile MISSING_SOME_NULLS_STATS_2 = + new 
TestDataFile( + "file_3.avro", + Row.of(), + 20, + // any value counts, including nulls + ImmutableMap.of(1, 20L, 3, 20L), + // null value counts + ImmutableMap.of(1, 20L, 3, 20L), + // nan value counts + null, + // lower bounds + ImmutableMap.of(1, toByteBuffer(IntegerType.get(), -33)), + // upper bounds + ImmutableMap.of(1, toByteBuffer(IntegerType.get(), 3333))); + + private static final DataFile[] dataFiles = { + FILE, MISSING_SOME_NULLS_STATS_1, MISSING_SOME_NULLS_STATS_2 + }; + + @Test + public void testIntAggregate() { + List list = + ImmutableList.of( + Expressions.countStar(), + Expressions.count("id"), + Expressions.max("id"), + Expressions.min("id")); + AggregateEvaluator aggregateEvaluator = AggregateEvaluator.create(SCHEMA, list); + + for (DataFile dataFile : dataFiles) { + aggregateEvaluator.update(dataFile); + } + + Assert.assertTrue(aggregateEvaluator.allAggregatorsValid()); + StructLike result = aggregateEvaluator.result(); + Object[] expected = {90L, 60L, 3333, -33}; + assertEvaluatorResult(result, expected); + } + + @Test + public void testAllNulls() { + List list = + ImmutableList.of( + Expressions.countStar(), + Expressions.count("all_nulls"), + Expressions.max("all_nulls"), + Expressions.min("all_nulls")); + AggregateEvaluator aggregateEvaluator = AggregateEvaluator.create(SCHEMA, list); + + for (DataFile dataFile : dataFiles) { + aggregateEvaluator.update(dataFile); + } + + Assert.assertTrue(aggregateEvaluator.allAggregatorsValid()); + StructLike result = aggregateEvaluator.result(); + Object[] expected = {90L, 0L, null, null}; + assertEvaluatorResult(result, expected); + } + + @Test + public void testSomeNulls() { + List list = + ImmutableList.of( + Expressions.countStar(), + Expressions.count("some_nulls"), + Expressions.max("some_nulls"), + Expressions.min("some_nulls")); + AggregateEvaluator aggregateEvaluator = AggregateEvaluator.create(SCHEMA, list); + for (DataFile dataFile : dataFiles) { + aggregateEvaluator.update(dataFile); + } + + 
Assert.assertFalse(aggregateEvaluator.allAggregatorsValid()); + StructLike result = aggregateEvaluator.result(); + Object[] expected = {90L, null, null, null}; + assertEvaluatorResult(result, expected); + } + + @Test + public void testNoStats() { + List list = + ImmutableList.of( + Expressions.countStar(), + Expressions.count("no_stats"), + Expressions.max("no_stats"), + Expressions.min("no_stats")); + AggregateEvaluator aggregateEvaluator = AggregateEvaluator.create(SCHEMA, list); + for (DataFile dataFile : dataFiles) { + aggregateEvaluator.update(dataFile); + } + + Assert.assertFalse(aggregateEvaluator.allAggregatorsValid()); + StructLike result = aggregateEvaluator.result(); + Object[] expected = {90L, null, null, null}; + assertEvaluatorResult(result, expected); + } + + private void assertEvaluatorResult(StructLike result, Object[] expected) { + Object[] actual = new Object[result.size()]; + for (int i = 0; i < result.size(); i++) { + actual[i] = result.get(i, Object.class); + } + + Assert.assertArrayEquals("equals", expected, actual); + } +} diff --git a/core/src/main/java/org/apache/iceberg/BaseTableScan.java b/core/src/main/java/org/apache/iceberg/BaseTableScan.java index 317e50e22e5c..f9399041a010 100644 --- a/core/src/main/java/org/apache/iceberg/BaseTableScan.java +++ b/core/src/main/java/org/apache/iceberg/BaseTableScan.java @@ -47,4 +47,9 @@ public CloseableIterable planTasks() { return TableScanUtil.planTasks( splitFiles, targetSplitSize(), splitLookback(), splitOpenFileCost()); } + + @Override + public TableScan withColStats() { + return newRefinedScan(table(), tableSchema(), context().withColStats(true)); + } } diff --git a/core/src/main/java/org/apache/iceberg/TableScanContext.java b/core/src/main/java/org/apache/iceberg/TableScanContext.java index 6a3c7cc6e93e..d938f16db1cb 100644 --- a/core/src/main/java/org/apache/iceberg/TableScanContext.java +++ b/core/src/main/java/org/apache/iceberg/TableScanContext.java @@ -374,4 +374,21 @@ TableScanContext
reportWith(MetricsReporter reporter) { fromSnapshotInclusive, reporter); } + + TableScanContext withColStats(boolean stats) { + return new TableScanContext( + snapshotId, + rowFilter, + ignoreResiduals, + caseSensitive, + stats, + projectedSchema, + selectedColumns, + options, + fromSnapshotId, + toSnapshotId, + planExecutor, + fromSnapshotInclusive, + metricsReporter); + } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java index b4cd6dc04616..779300931188 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java @@ -22,6 +22,8 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; +import java.util.List; +import java.util.Locale; import java.util.Map; import org.apache.iceberg.AssertHelpers; import org.apache.iceberg.RowDelta; @@ -31,10 +33,12 @@ import org.apache.iceberg.exceptions.CommitStateUnknownException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.spark.source.SparkTable; import org.apache.iceberg.spark.source.TestSparkCatalog; import org.apache.spark.SparkException; import org.apache.spark.sql.connector.catalog.Identifier; +import org.junit.Assert; import org.junit.Test; import org.junit.runners.Parameterized; @@ -118,4 +122,32 @@ public void testCommitUnknownException() { ImmutableList.of(row(1, "hr", "c1"), row(3, "hr", "c1")), sql("SELECT * FROM %s ORDER BY id", "dummy_catalog.default.table")); } + + @Test + public void testAggregatePushDownInMergeOnReadDelete() { + 
createAndInitTable("id LONG, data INT"); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + + sql("DELETE FROM %s WHERE data = 1111", tableName); + String select = "SELECT max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)".toLowerCase(Locale.ROOT)) + || explainString.contains("min(data)".toLowerCase(Locale.ROOT)) + || explainString.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "min/max/count not pushed down for deleted", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6666, 2222, 5L}); + assertEquals("min/max/count push down", expected, actual); + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java index a2dcfed96945..a44929aa30ab 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -232,4 +232,13 @@ public boolean preserveDataGrouping() { .defaultValue(SparkSQLProperties.PRESERVE_DATA_GROUPING_DEFAULT) .parse(); } + + public boolean aggregatePushDownEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.AGGREGATE_PUSH_DOWN_ENABLED) + .sessionConf(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED) + .defaultValue(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT) + .parse(); + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java index 
1007f634ec98..60b2e85f00e0 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java @@ -90,4 +90,6 @@ private SparkReadOptions() {} public static final String VERSION_AS_OF = "versionAsOf"; public static final String TIMESTAMP_AS_OF = "timestampAsOf"; + + public static final String AGGREGATE_PUSH_DOWN_ENABLED = "aggregate-push-down-enabled"; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index 01e1aa0047c6..a5484d26c17f 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -48,6 +48,11 @@ private SparkSQLProperties() {} "spark.sql.iceberg.planning.preserve-data-grouping"; public static final boolean PRESERVE_DATA_GROUPING_DEFAULT = false; + // Controls whether to push down aggregate (MAX/MIN/COUNT) to Iceberg + public static final String AGGREGATE_PUSH_DOWN_ENABLED = + "spark.sql.iceberg.aggregate-push-down-enabled"; + public static final boolean AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT = true; + // Controls write distribution mode public static final String DISTRIBUTION_MODE = "spark.sql.iceberg.distribution-mode"; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAggregates.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAggregates.java new file mode 100644 index 000000000000..15ea53495b8a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAggregates.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; + +public class SparkAggregates { + + private SparkAggregates() {} + + private static final Map, Operation> AGGREGATES = + ImmutableMap., Operation>builder() + .put(Count.class, Operation.COUNT) + .put(CountStar.class, Operation.COUNT_STAR) + .put(Max.class, Operation.MAX) + .put(Min.class, Operation.MIN) + .build(); + + public static Expression convert(AggregateFunc aggregate) { + Operation op = AGGREGATES.get(aggregate.getClass()); + if (op != null) { + switch (op) { + case COUNT: + Count countAgg = (Count) aggregate; + assert (countAgg.column() instanceof NamedReference); + return 
Expressions.count(SparkUtil.toColumnName((NamedReference) countAgg.column())); + case COUNT_STAR: + return Expressions.countStar(); + case MAX: + Max maxAgg = (Max) aggregate; + assert (maxAgg.column() instanceof NamedReference); + return Expressions.max(SparkUtil.toColumnName((NamedReference) maxAgg.column())); + case MIN: + Min minAgg = (Min) aggregate; + assert (minAgg.column() instanceof NamedReference); + return Expressions.min(SparkUtil.toColumnName((NamedReference) minAgg.column())); + } + } + + throw new UnsupportedOperationException("Unsupported aggregate: " + aggregate); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java new file mode 100644 index 000000000000..8d3b6b7bdea7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.stream.Collectors; +import org.apache.iceberg.Table; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.LocalScan; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class SparkLocalScan implements LocalScan { + + private final Table table; + private final StructType readSchema; + private final InternalRow[] rows; + + SparkLocalScan(Table table, StructType readSchema, InternalRow[] rows) { + this.table = table; + this.readSchema = readSchema; + this.rows = rows; + } + + @Override + public InternalRow[] rows() { + return rows; + } + + @Override + public StructType readSchema() { + return readSchema; + } + + @Override + public String description() { + String fields = + Arrays.stream(readSchema.fields()).map(StructField::name).collect(Collectors.joining(", ")); + return String.format("%s [%s]", table, fields); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index dcd89b15a22a..1bc751e30b86 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -18,35 +18,51 @@ */ package org.apache.iceberg.spark.source; +import java.io.IOException; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.apache.iceberg.BaseTable; import org.apache.iceberg.BatchScan; +import org.apache.iceberg.FileScanTask; import org.apache.iceberg.IncrementalAppendScan; import org.apache.iceberg.IncrementalChangelogScan; import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.MetricsModes; import org.apache.iceberg.Schema; import 
org.apache.iceberg.Snapshot; +import org.apache.iceberg.StructLike; import org.apache.iceberg.Table; import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.expressions.AggregateEvaluator; import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.BoundAggregate; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.expressions.ExpressionUtil; import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkFilters; import org.apache.iceberg.spark.SparkReadConf; import org.apache.iceberg.spark.SparkReadOptions; import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsPushDownAggregates; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; import org.apache.spark.sql.connector.read.SupportsReportStatistics; @@ -59,12 +75,15 @@ public class SparkScanBuilder implements ScanBuilder, + SupportsPushDownAggregates, SupportsPushDownFilters, 
SupportsPushDownRequiredColumns, SupportsReportStatistics { private static final Logger LOG = LoggerFactory.getLogger(SparkScanBuilder.class); private static final Filter[] NO_FILTERS = new Filter[0]; + private StructType pushedAggregateSchema; + private Scan localScan; private final SparkSession spark; private final Table table; @@ -153,6 +172,147 @@ public Filter[] pushedFilters() { return pushedFilters; } + @Override + public boolean pushAggregation(Aggregation aggregation) { + if (!canPushDownAggregation(aggregation)) { + return false; + } + + AggregateEvaluator aggregateEvaluator; + List> expressions = + Lists.newArrayListWithExpectedSize(aggregation.aggregateExpressions().length); + + for (AggregateFunc aggregateFunc : aggregation.aggregateExpressions()) { + try { + Expression expr = SparkAggregates.convert(aggregateFunc); + if (expr != null) { + Expression bound = Binder.bind(schema.asStruct(), expr, caseSensitive); + expressions.add((BoundAggregate) bound); + } + } catch (UnsupportedOperationException e) { + LOG.info( + "Skipping aggregate pushdown: AggregateFunc {} can't be converted to iceberg Expression", + aggregateFunc, + e); + return false; + } catch (IllegalArgumentException e) { + LOG.info("Skipping aggregate pushdown: Bind failed for AggregateFunc {}", aggregateFunc, e); + return false; + } + } + + aggregateEvaluator = AggregateEvaluator.create(expressions); + + if (!metricsModeSupportsAggregatePushDown(aggregateEvaluator.aggregates())) { + return false; + } + + TableScan scan = table.newScan().withColStats(); + Snapshot snapshot = readSnapshot(); + if (snapshot == null) { + LOG.info("Skipping aggregate pushdown: table snapshot is null"); + return false; + } + scan = scan.useSnapshot(snapshot.snapshotId()); + scan = configureSplitPlanning(scan); + scan = scan.filter(filterExpression()); + + try (CloseableIterable fileScanTasks = scan.planFiles()) { + List tasks = ImmutableList.copyOf(fileScanTasks); + for (FileScanTask task : tasks) { + if 
(!task.deletes().isEmpty()) { + LOG.info("Skipping aggregate pushdown: detected row level deletes"); + return false; + } + + aggregateEvaluator.update(task.file()); + } + } catch (IOException e) { + LOG.info("Skipping aggregate pushdown: ", e); + return false; + } + + if (!aggregateEvaluator.allAggregatorsValid()) { + return false; + } + + pushedAggregateSchema = + SparkSchemaUtil.convert(new Schema(aggregateEvaluator.resultType().fields())); + InternalRow[] pushedAggregateRows = new InternalRow[1]; + StructLike structLike = aggregateEvaluator.result(); + pushedAggregateRows[0] = + new StructInternalRow(aggregateEvaluator.resultType()).setStruct(structLike); + localScan = new SparkLocalScan(table, pushedAggregateSchema, pushedAggregateRows); + + return true; + } + + private boolean canPushDownAggregation(Aggregation aggregation) { + if (!(table instanceof BaseTable)) { + return false; + } + + if (!readConf.aggregatePushDownEnabled()) { + return false; + } + + // If group by expression is the same as the partition, the statistics information can still + // be used to calculate min/max/count, will enable aggregate push down in next phase. 
+ // TODO: enable aggregate push down for partition col group by expression + if (aggregation.groupByExpressions().length > 0) { + LOG.info("Skipping aggregate pushdown: group by aggregation push down is not supported"); + return false; + } + + return true; + } + + private Snapshot readSnapshot() { + Snapshot snapshot; + if (readConf.snapshotId() != null) { + snapshot = table.snapshot(readConf.snapshotId()); + } else { + snapshot = table.currentSnapshot(); + } + + return snapshot; + } + + private boolean metricsModeSupportsAggregatePushDown(List> aggregates) { + MetricsConfig config = MetricsConfig.forTable(table); + for (BoundAggregate aggregate : aggregates) { + String colName = aggregate.columnName(); + if (!colName.equals("*")) { + MetricsModes.MetricsMode mode = config.columnMode(colName); + if (mode instanceof MetricsModes.None) { + LOG.info("Skipping aggregate pushdown: No metrics for column {}", colName); + return false; + } else if (mode instanceof MetricsModes.Counts) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + LOG.info( + "Skipping aggregate pushdown: Cannot produce min or max from count for column {}", + colName); + return false; + } + } else if (mode instanceof MetricsModes.Truncate) { + // lower_bounds and upper_bounds may be truncated, so disable push down + if (aggregate.type().typeId() == Type.TypeID.STRING) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + LOG.info( + "Skipping aggregate pushdown: Cannot produce min or max from truncated values for column {}", + colName); + return false; + } + } + } + } + } + + return true; + } + @Override public void pruneColumns(StructType requestedSchema) { StructType requestedProjection = @@ -188,6 +348,14 @@ private Schema schemaWithMetadataColumns() { @Override public Scan build() { + if (localScan != null) { + return localScan; + } else { + return buildBatchScan(); + } + } + + private Scan 
buildBatchScan() { Long snapshotId = readConf.snapshotId(); Long asOfTimestamp = readConf.asOfTimestamp(); String branch = readConf.branch(); diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java new file mode 100644 index 000000000000..9cecf89bba2a --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java @@ -0,0 +1,680 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestAggregatePushDown extends SparkCatalogTestBase { + + public TestAggregatePushDown( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @BeforeClass + public static void startMetastoreAndSpark() { + SparkTestBase.metastore = new TestHiveMetastore(); + metastore.start(); + SparkTestBase.hiveConf = metastore.hiveConf(); + + SparkTestBase.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.iceberg.aggregate-push-down-enabled", "true") + .enableHiveSupport() + .getOrCreate(); + + SparkTestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. 
ignore the create error + } + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDifferentDataTypesAggregatePushDownInPartitionedTable() { + testDifferentDataTypesAggregatePushDown(true); + } + + @Test + public void testDifferentDataTypesAggregatePushDownInNonPartitionedTable() { + testDifferentDataTypesAggregatePushDown(false); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private void testDifferentDataTypesAggregatePushDown(boolean hasPartitionCol) { + String createTable; + if (hasPartitionCol) { + createTable = + "CREATE TABLE %s (id LONG, int_data INT, boolean_data BOOLEAN, float_data FLOAT, double_data DOUBLE, " + + "decimal_data DECIMAL(14, 2), binary_data binary) USING iceberg PARTITIONED BY (id)"; + } else { + createTable = + "CREATE TABLE %s (id LONG, int_data INT, boolean_data BOOLEAN, float_data FLOAT, double_data DOUBLE, " + + "decimal_data DECIMAL(14, 2), binary_data binary) USING iceberg"; + } + + sql(createTable, tableName); + sql( + "INSERT INTO TABLE %s VALUES " + + "(1, null, false, null, null, 11.11, X'1111')," + + " (1, null, true, 2.222, 2.222222, 22.22, X'2222')," + + " (2, 33, false, 3.333, 3.333333, 33.33, X'3333')," + + " (2, 44, true, null, 4.444444, 44.44, X'4444')," + + " (3, 55, false, 5.555, 5.555555, 55.55, X'5555')," + + " (3, null, true, null, 6.666666, 66.66, null) ", + tableName); + + String select = + "SELECT count(*), max(id), min(id), count(id), " + + "max(int_data), min(int_data), count(int_data), " + + "max(boolean_data), min(boolean_data), count(boolean_data), " + + "max(float_data), min(float_data), count(float_data), " + + "max(double_data), min(double_data), count(double_data), " + + "max(decimal_data), min(decimal_data), count(decimal_data), " + + "max(binary_data), min(binary_data), count(binary_data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean 
explainContainsPushDownAggregates = false; + if (explainString.contains("count(*)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(id)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(id)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(id)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(int_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(int_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(int_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(boolean_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(boolean_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(boolean_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(float_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(float_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(float_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(double_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(double_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(double_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(decimal_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(decimal_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(decimal_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(binary_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(binary_data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(binary_data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + 6L, + 3L, + 1L, + 6L, + 55, + 33, + 3L, + true, + false, + 6L, + 5.555f, + 2.222f, + 3L, + 
6.666666, + 2.222222, + 5L, + new BigDecimal("66.66"), + new BigDecimal("11.11"), + 6L, + new byte[] {85, 85}, + new byte[] {17, 17}, + 5L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void testDateAndTimestampWithPartition() { + sql( + "CREATE TABLE %s (id bigint, data string, d date, ts timestamp) USING iceberg PARTITIONED BY (id)", + tableName); + sql( + "INSERT INTO %s VALUES (1, '1', date('2021-11-10'), null)," + + "(1, '2', date('2021-11-11'), timestamp('2021-11-11 22:22:22')), " + + "(2, '3', date('2021-11-12'), timestamp('2021-11-12 22:22:22')), " + + "(2, '4', date('2021-11-13'), timestamp('2021-11-13 22:22:22')), " + + "(3, '5', null, timestamp('2021-11-14 22:22:22')), " + + "(3, '6', date('2021-11-14'), null)", + tableName); + String select = "SELECT max(d), min(d), count(d), max(ts), min(ts), count(ts) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(d)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(d)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(d)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(ts)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(ts)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(ts)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + Date.valueOf("2021-11-14"), + Date.valueOf("2021-11-10"), + 5L, + Timestamp.valueOf("2021-11-14 22:22:22.0"), + Timestamp.valueOf("2021-11-11 22:22:22.0"), + 4L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void 
testAggregateNotPushDownIfOneCantPushDown() { + sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + String select = "SELECT COUNT(data), SUM(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, 23331.0}); + assertEquals("expected and actual should equal", expected, actual); + } + + @Test + public void testAggregatePushDownWithMetricsMode() { + sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "id", "counts"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "data", "none"); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666)", + tableName); + + String select1 = "SELECT COUNT(data) FROM %s"; + + List explain1 = sql("EXPLAIN " + select1, tableName); + String explainString1 = explain1.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString1.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + // count(data) is not pushed down because the metrics mode is `none` + Assert.assertFalse( + "explain 
should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual1 = sql(select1, tableName); + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected1, actual1); + + String select2 = "SELECT COUNT(id) FROM %s"; + List explain2 = sql("EXPLAIN " + select2, tableName); + String explainString2 = explain2.get(0)[0].toString(); + if (explainString2.contains("count(id)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + // count(id) is pushed down because the metrics mode is `counts` + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual2 = sql(select2, tableName); + List expected2 = Lists.newArrayList(); + expected2.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected2, actual2); + + String select3 = "SELECT COUNT(id), MAX(id) FROM %s"; + explainContainsPushDownAggregates = false; + List explain3 = sql("EXPLAIN " + select3, tableName); + String explainString3 = explain3.get(0)[0].toString(); + if (explainString3.contains("count(id)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + // COUNT(id), MAX(id) are not pushed down because MAX(id) is not pushed down (metrics mode is + // `counts`) + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual3 = sql(select3, tableName); + List expected3 = Lists.newArrayList(); + expected3.add(new Object[] {6L, 3L}); + assertEquals("expected and actual should equal", expected3, actual3); + } + + @Test + public void testAggregateNotPushDownForStringType() { + sql("CREATE TABLE %s (id LONG, data STRING) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, '1111'), (1, '2222'), (2, '3333'), (2, '4444'), (3, '5555'), (3, '6666') ", + tableName); + + String select1 = "SELECT 
MAX(id), MAX(data) FROM %s"; + + List explain1 = sql("EXPLAIN " + select1, tableName); + String explainString1 = explain1.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString1.contains("max(id)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual1 = sql(select1, tableName); + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {3L, "6666"}); + assertEquals("expected and actual should equal", expected1, actual1); + + String select2 = "SELECT COUNT(data) FROM %s"; + List explain2 = sql("EXPLAIN " + select2, tableName); + String explainString2 = explain2.get(0)[0].toString(); + if (explainString2.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual2 = sql(select2, tableName); + List expected2 = Lists.newArrayList(); + expected2.add(new Object[] {6L}); + assertEquals("min/max/count push down", expected2, actual2); + + explainContainsPushDownAggregates = false; + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "full"); + String select3 = "SELECT count(data), max(data) FROM %s"; + List explain3 = sql("EXPLAIN " + select3, tableName); + String explainString3 = explain3.get(0)[0].toString(); + if (explainString3.contains("count(data)".toLowerCase(Locale.ROOT)) + && explainString3.contains("max(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual3 = sql(select3, tableName); + List expected3 = Lists.newArrayList(); + expected3.add(new Object[] {6L, "6666"}); + 
assertEquals("expected and actual should equal", expected3, actual3); + } + + @Test + public void testAggregatePushDownWithDataFilter() { + testAggregatePushDownWithFilter(false); + } + + @Test + public void testAggregatePushDownWithPartitionFilter() { + testAggregatePushDownWithFilter(true); + } + + private void testAggregatePushDownWithFilter(boolean partitionFilerOnly) { + String createTable; + if (!partitionFilerOnly) { + createTable = "CREATE TABLE %s (id LONG, data INT) USING iceberg"; + } else { + createTable = "CREATE TABLE %s (id LONG, data INT) USING iceberg PARTITIONED BY (id)"; + } + + sql(createTable, tableName); + + sql( + "INSERT INTO TABLE %s VALUES" + + " (1, 11)," + + " (1, 22)," + + " (2, 33)," + + " (2, 44)," + + " (3, 55)," + + " (3, 66) ", + tableName); + + String select = "SELECT MIN(data) FROM %s WHERE id > 1"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("min(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + if (!partitionFilerOnly) { + // Filters are not completely pushed down, we can't push down aggregates + Assert.assertFalse( + "explain should not contain the pushed down aggregates", + explainContainsPushDownAggregates); + } else { + // Filters are not completely pushed down, we can push down aggregates + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + } + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {33}); + assertEquals("expected and actual should equal", expected, actual); + } + + @Test + public void testAggregateWithComplexType() { + sql("CREATE TABLE %s (id INT, complex STRUCT) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))," + + "(2, named_struct(\"c1\", 2, 
\"c2\", \"v2\")), (3, null)", + tableName); + String select1 = "SELECT count(complex), count(id) FROM %s"; + List explain = sql("EXPLAIN " + select1, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("count(complex)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "count not pushed down for complex types", explainContainsPushDownAggregates); + + List actual = sql(select1, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {2L, 3L}); + assertEquals("count not push down", actual, expected); + + String select2 = "SELECT max(complex) FROM %s"; + explain = sql("EXPLAIN " + select2, tableName); + explainString = explain.get(0)[0].toString(); + explainContainsPushDownAggregates = false; + if (explainString.contains("max(complex)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse("max not pushed down for complex types", explainContainsPushDownAggregates); + } + + @Test + public void testAggregatePushDownInDeleteCopyOnWrite() { + sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + sql("DELETE FROM %s WHERE data = 1111", tableName); + String select = "SELECT max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue("min/max/count pushed down for deleted", explainContainsPushDownAggregates); + + 
List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6666, 2222, 5L}); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void testAggregatePushDownForTimeTravel() { + sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected1 = sql("SELECT count(id) FROM %s", tableName); + + sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName); + List expected2 = sql("SELECT count(id) FROM %s", tableName); + + List explain1 = + sql("EXPLAIN SELECT count(id) FROM %s VERSION AS OF %s", tableName, snapshotId); + String explainString1 = explain1.get(0)[0].toString(); + boolean explainContainsPushDownAggregates1 = false; + if (explainString1.contains("count(id)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates1 = true; + } + Assert.assertTrue("count pushed down", explainContainsPushDownAggregates1); + + List actual1 = + sql("SELECT count(id) FROM %s VERSION AS OF %s", tableName, snapshotId); + assertEquals("count push down", expected1, actual1); + + List explain2 = sql("EXPLAIN SELECT count(id) FROM %s", tableName); + String explainString2 = explain2.get(0)[0].toString(); + boolean explainContainsPushDownAggregates2 = false; + if (explainString2.contains("count(id)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates2 = true; + } + + Assert.assertTrue("count pushed down", explainContainsPushDownAggregates2); + + List actual2 = sql("SELECT count(id) FROM %s", tableName); + assertEquals("count push down", expected2, actual2); + } + + @Test + public void testAllNull() { + sql("CREATE TABLE %s (id int, data int) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, null)," + + "(1, null), " + + "(2, 
null), " + + "(2, null), " + + "(3, null), " + + "(3, null)", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(data)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, null, null, 0L}); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void testAllNaN() { + sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, float('nan'))," + + "(1, float('nan')), " + + "(2, float('nan')), " + + "(2, float('nan')), " + + "(3, float('nan')), " + + "(3, float('nan'))", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)".toLowerCase(Locale.ROOT)) + || explainString.contains("min(data)".toLowerCase(Locale.ROOT)) + || explainString.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, Float.NaN, Float.NaN, 6L}); + assertEquals("expected and 
actual should equal", expected, actual); + } + + @Test + public void testNaN() { + sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, float('nan'))," + + "(1, float('nan')), " + + "(2, 2), " + + "(2, float('nan')), " + + "(3, float('nan')), " + + "(3, 1)", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)".toLowerCase(Locale.ROOT)) + || explainString.contains("min(data)".toLowerCase(Locale.ROOT)) + || explainString.contains("count(data)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, Float.NaN, 1.0F, 6L}); + assertEquals("expected and actual should equal", expected, actual); + } + + @Test + public void testInfinity() { + sql( + "CREATE TABLE %s (id int, data1 float, data2 double, data3 double) USING iceberg PARTITIONED BY (id)", + tableName); + sql( + "INSERT INTO %s VALUES (1, float('-infinity'), double('infinity'), 1.23), " + + "(1, float('-infinity'), double('infinity'), -1.23), " + + "(1, float('-infinity'), double('infinity'), double('infinity')), " + + "(1, float('-infinity'), double('infinity'), 2.23), " + + "(1, float('-infinity'), double('infinity'), double('-infinity')), " + + "(1, float('-infinity'), double('infinity'), -2.23)", + tableName); + String select = + "SELECT count(*), max(data1), min(data1), count(data1), max(data2), min(data2), count(data2), max(data3), min(data3), count(data3) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString 
= explain.get(0)[0].toString(); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data1)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(data1)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(data1)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(data2)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(data2)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(data2)".toLowerCase(Locale.ROOT)) + && explainString.contains("max(data3)".toLowerCase(Locale.ROOT)) + && explainString.contains("min(data3)".toLowerCase(Locale.ROOT)) + && explainString.contains("count(data3)".toLowerCase(Locale.ROOT))) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + 6L, + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + 6L, + Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY, + 6L, + Double.POSITIVE_INFINITY, + Double.NEGATIVE_INFINITY, + 6L + }); + assertEquals("min/max/count push down", expected, actual); + } +}