diff --git a/api/src/main/java/org/apache/iceberg/expressions/AggregateUtil.java b/api/src/main/java/org/apache/iceberg/expressions/AggregateUtil.java new file mode 100644 index 000000000000..1273e0f8c733 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/expressions/AggregateUtil.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.expressions; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; + +/** Aggregate utility methods. */ +public class AggregateUtil { + private AggregateUtil() {} + + /** + * Create a NestedField for this Aggregate Expression. This NestedField is used to build the + * pushed down aggregate schema. + * + *

e.g. SELECT COUNT(*), MAX(col1), MIN(col1), MAX(col2), MIN(col3) FROM table; + * + *

Suppose the table schema is Schema(Types.NestedField.required(1, "col1", + * Types.IntegerType.get()), Types.NestedField.required(2, "col2", Types.StringType.get()), + * Types.NestedField.required(3, "col3", Types.StringType.get())); + * + *

The returned NestedField for the aggregates are Types.NestedField.required(1, COUNT(*), + * Types.LongType.get()) Types.NestedField.required(2, MAX(col1), Types.IntegerType.get()) + * Types.NestedField.required(3, MIN(col1), Types.IntegerType.get()) Types.NestedField.required(4, + * MAX(col2), Types.StringType.get()) Types.NestedField.required(5, MIN(col3), + * Types.StringType.get()) + */ + public static Types.NestedField buildAggregateNestedField(BoundAggregate aggregate, int index) { + return aggregate.nestedField(index); + } + + /** + * Returns the column name this aggregate function is on. e.g. SELECT Max(col3) FROM table; This + * method returns col3 + */ + public static String getAggregateColumnName(BoundAggregate aggregate) { + return aggregate.columnName(); + } + + /** + * Returns the data type of this Aggregate Expression. The data type for COUNT is always Long. The + * data type for MAX and MIN are the same as the data type of the column this aggregate is applied + * on. + */ + public static Type getAggregateType(BoundAggregate aggregate) { + return aggregate.type(); + } + + /** + * Returns the index of this Aggregate column in table schema. e.g. SELECT Max(col3) FROM table; + * Suppose the table has columns (col1, col2, col3), this method returns 2. 
+ */ + public static int columnIndexInTableSchema( + BoundAggregate aggregate, Table table, boolean caseSensitive) { + List columns = table.schema().columns(); + for (int i = 0; i < columns.size(); i++) { + if (aggregate.columnName().equals("*")) { + return -1; + } + if (caseSensitive) { + if (aggregate.columnName().equals(columns.get(i).name())) { + return i; + } + } else { + if (aggregate.columnName().equalsIgnoreCase(columns.get(i).name())) { + return i; + } + } + } + throw new ValidationException( + "Aggregate is on an invalid table column %s", aggregate.columnName()); + } +} diff --git a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java index 650271b3b78a..41b8eaa6278d 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java +++ b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java @@ -29,7 +29,8 @@ protected BoundAggregate(Operation op, BoundTerm term) { @Override public C eval(StructLike struct) { - throw new UnsupportedOperationException(this.getClass().getName() + " does not implement eval"); + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement eval."); } @Override @@ -37,6 +38,29 @@ public BoundReference ref() { return term().ref(); } + public Types.NestedField nestedField(int index) { + if (op() == Operation.COUNT_STAR) { + return Types.NestedField.required(index, "COUNT(*)", Types.LongType.get()); + } else { + if (term() instanceof BoundReference) { + if (op() == Operation.COUNT) { + return Types.NestedField.required( + index, "COUNT(" + term().ref().name() + ")", Types.LongType.get()); + } else if (op() == Operation.MAX) { + return Types.NestedField.required( + index, "MAX(" + term().ref().name() + ")", term().type()); + } else if (op() == Operation.MIN) { + return Types.NestedField.required( + index, "MIN(" + term().ref().name() + ")", term().type()); + } else { + throw new 
UnsupportedOperationException(op() + " is not supported."); + } + } else { + throw new UnsupportedOperationException("Aggregate BoundTransform is not supported."); + } + } + } + public Type type() { if (op() == Operation.COUNT || op() == Operation.COUNT_STAR) { return Types.LongType.get(); @@ -44,4 +68,12 @@ public Type type() { return term().type(); } } + + public String columnName() { + if (op() == Operation.COUNT_STAR) { + return "*"; + } else { + return ref().name(); + } + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java index 85c71827d7ac..0559f38a995f 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -228,4 +228,15 @@ public Long streamFromTimestamp() { .defaultValue(Long.MIN_VALUE) .parse(); } + + public boolean aggregatePushDown() { + boolean enable = + confParser + .booleanConf() + .option(SparkReadOptions.AGGREGATE_PUSH_DOWN_ENABLED) + .sessionConf(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED) + .defaultValue(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT) + .parse(); + return enable; + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java index 96e09d70ef65..bc9797a82207 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java @@ -84,4 +84,6 @@ private SparkReadOptions() {} public static final String VERSION_AS_OF = "versionAsOf"; public static final String TIMESTAMP_AS_OF = "timestampAsOf"; + + public static final String AGGREGATE_PUSH_DOWN_ENABLED = "aggregatePushDownEnabled"; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java 
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index fa8bd719f391..17a39336478e 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -42,4 +42,8 @@ private SparkSQLProperties() {} // Controls whether to check the order of fields during writes public static final String CHECK_ORDERING = "spark.sql.iceberg.check-ordering"; public static final boolean CHECK_ORDERING_DEFAULT = true; + + // Controls whether to push down aggregate (MAX/MIN/COUNT) to Iceberg + public static final String AGGREGATE_PUSH_DOWN_ENABLED = "spark.sql.iceberg.aggregate_pushdown"; + public static final boolean AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT = true; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAggregates.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAggregates.java new file mode 100644 index 000000000000..4b8c971380b5 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAggregates.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; + +public class SparkAggregates { + + private SparkAggregates() {} + + private static final Map, Operation> AGGREGATES = + ImmutableMap., Operation>builder() + .put(Count.class, Operation.COUNT) + .put(CountStar.class, Operation.COUNT_STAR) + .put(Max.class, Operation.MAX) + .put(Min.class, Operation.MIN) + .build(); + + public static Expression convert(AggregateFunc aggregate) { + Operation op = AGGREGATES.get(aggregate.getClass()); + if (op != null) { + switch (op) { + case COUNT: + Count countAgg = (Count) aggregate; + assert (countAgg.column() instanceof NamedReference); + return Expressions.count(SparkUtil.toColumnName((NamedReference) countAgg.column())); + case COUNT_STAR: + return Expressions.countStar(); + case MAX: + Max maxAgg = (Max) aggregate; + assert (maxAgg.column() instanceof NamedReference); + return Expressions.max(SparkUtil.toColumnName((NamedReference) maxAgg.column())); + case MIN: + Min minAgg = (Min) aggregate; + assert (minAgg.column() instanceof NamedReference); + return Expressions.min(SparkUtil.toColumnName((NamedReference) minAgg.column())); + } + } + return null; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java 
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java new file mode 100644 index 000000000000..78f6a90458cd --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.stream.Collectors; +import org.apache.iceberg.Table; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.LocalScan; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class SparkLocalScan implements LocalScan { + + private final Table table; + private final StructType aggregateSchema; + private final InternalRow[] rows; + + SparkLocalScan(Table table, StructType aggregateSchema, InternalRow[] rows) { + this.table = table; + this.aggregateSchema = aggregateSchema; + this.rows = rows; + } + + @Override + public InternalRow[] rows() { + return rows; + } + + @Override + public StructType readSchema() { + return aggregateSchema; + } + + @Override + public String description() { + String aggregates = + Arrays.stream(aggregateSchema.fields()) + .map(StructField::name) + .collect(Collectors.joining(", ")); + return String.format("%s [aggregates=%s]", table, aggregates); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPushedDownAggregateUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPushedDownAggregateUtil.java new file mode 100644 index 000000000000..5ea4c743cedc --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPushedDownAggregateUtil.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.List; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.MetricsModes; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.AggregateUtil; +import org.apache.iceberg.expressions.BoundAggregate; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import scala.collection.JavaConverters; + +/** Helper methods for working with Spark aggregate push down. 
*/ +public class SparkPushedDownAggregateUtil { + static final int LOWER_BOUNDS_COLUMN_INDEX = 0; + static final int UPPER_BOUNDS_COLUMN_INDEX = 1; + static final int RECORD_COUNT_COLUMN_INDEX = 2; + static final int NULL_COUNT_COLUMN_INDEX = 3; + + private SparkPushedDownAggregateUtil() {} + + public static boolean metricsModeSupportsAggregatePushDown( + Table table, List aggregates) { + MetricsConfig config = MetricsConfig.forTable(table); + for (BoundAggregate aggregate : aggregates) { + String colName = AggregateUtil.getAggregateColumnName(aggregate); + if (!colName.equals("*")) { + MetricsModes.MetricsMode mode = config.columnMode(colName); + if (mode.toString().equals("none")) { + return false; + } else if (mode.toString().equals("counts")) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + return false; + } + } else if (mode.toString().contains("truncate")) { + if (AggregateUtil.getAggregateType(aggregate).typeId() == Type.TypeID.STRING) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + return false; + } + } + } + } + } + + return true; + } + + public static InternalRow[] constructInternalRowForPushedDownAggregate( + SparkSession spark, + Table table, + List aggregates, + List indexInTable) { + List valuesInSparkInternalRow = Lists.newArrayList(); + Row[] row = SparkPushedDownAggregateUtil.getStatisticRow(spark, table); + for (int i = 0; i < aggregates.size(); i++) { + BoundAggregate aggregate = aggregates.get(i); + Type type = AggregateUtil.getAggregateType(aggregate); + valuesInSparkInternalRow.add(getAggregateValue(aggregate, row, indexInTable.get(i), type)); + } + + InternalRow[] rows = new InternalRow[1]; + rows[0] = InternalRow.fromSeq(JavaConverters.asScalaBuffer(valuesInSparkInternalRow).toSeq()); + return rows; + } + + private static Row[] getStatisticRow(SparkSession spark, Table table) { + Dataset metadataRows = + 
SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.DATA_FILES); + Dataset dataset = + metadataRows.selectExpr( + "lower_bounds", "upper_bounds", "record_count", "null_value_counts"); + return (Row[]) dataset.collect(); + } + + // statisticRows: "lower_bounds", "upper_bounds", "record_count", "null_value_counts" + // here is one example of the rows: + // +---------------------------+--------------------------+------------+------------------------+ + // |lower_bounds |upper_bounds |record_count|null_value_counts | + // +---------------------------+--------------------------+------------+------------------------+ + // |{1 -> a, 2 -> a, 3 -> null}|{1 -> n, 2 -> c, 3 -> vvv}|2 |{1 -> 0, 2 -> 0, 3 -> 0}| + // |{1 -> b, 2 -> l, 3 -> ccc} |{1 -> m, 2 -> l, 3 -> ccc}|1 |{1 -> 0, 2 -> 0, 3 -> 0}| + // |{1 -> a, 2 -> b, 3 -> cc} |{1 -> v, 2 -> b, 3 -> cc} |1 |{1 -> 0, 2 -> 0, 3 -> 0}| + // |{1 -> n, 2 -> m, 3 -> bbb} |{1 -> o, 2 -> n, 3 -> mmm}|2 |{1 -> 0, 2 -> 0, 3 -> 0}| + // +-----------------------------+------------------------+------------+------------------------+ + // + // index is the index of this column in the original table schema + // e.g. MAX (col2), suppose the table schema is col1, col2, col3, then this index is 1. 
+ // this (index + 1) is the key in the map values in lower_bounds/upper_bounds/null_value_counts + private static T getAggregateValue( + BoundAggregate aggregate, Row[] statisticRows, int index, Type type) { + + switch (aggregate.op()) { + case COUNT: + Long count = 0L; + for (int i = 0; i < statisticRows.length; i++) { + count += statisticRows[i].getLong(RECORD_COUNT_COLUMN_INDEX); + } + + Long numOfNulls = getNullValueCount(statisticRows, index); + return (T) (Long.valueOf(count - numOfNulls)); + case COUNT_STAR: + Long countStar = 0L; + for (int i = 0; i < statisticRows.length; i++) { + countStar += statisticRows[i].getLong(RECORD_COUNT_COLUMN_INDEX); + } + + return (T) countStar; + case MAX: + return getMinOrMax(statisticRows, index, type, false); + case MIN: + return getMinOrMax(statisticRows, index, type, true); + default: + throw new UnsupportedOperationException("Invalid aggregate: " + aggregate.op()); + } + } + + private static long getNullValueCount(Row[] statisticRows, int index) { + long numOfNulls = 0L; + for (int i = 0; i < statisticRows.length; i++) { + Long value = (Long) statisticRows[i].getJavaMap(NULL_COUNT_COLUMN_INDEX).get(index + 1); + numOfNulls += value; + } + + return numOfNulls; + } + + @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"}) + private static T getMinOrMax(Row[] statisticRows, int index, Type type, boolean isMin) { + T result = null; + boolean isString = false; + boolean isBinary = false; + boolean isDecimal = false; + int columIndex = LOWER_BOUNDS_COLUMN_INDEX; + if (!isMin) { + columIndex = UPPER_BOUNDS_COLUMN_INDEX; + } + + for (int i = 0; i < statisticRows.length; i++) { + byte[] valueInBytes = (byte[]) statisticRows[i].getJavaMap(columIndex).get(index + 1); + if (valueInBytes != null) { + switch (type.typeId()) { + case BOOLEAN: + boolean booleanValue = + Conversions.fromByteBuffer(Types.BooleanType.get(), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + result = (T) Boolean.TRUE; + if 
(!booleanValue) { + return (T) Boolean.FALSE; + } + } else { + result = (T) Boolean.FALSE; + if (booleanValue) { + return (T) Boolean.TRUE; + } + } + + return result; + case INTEGER: + case DATE: + int intValue = + Conversions.fromByteBuffer(Types.IntegerType.get(), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + if (result == null + || (Literal.of(intValue)).comparator().compare(intValue, (Integer) result) < 0) { + result = (T) Integer.valueOf(intValue); + } + } else { + if (result == null + || (Literal.of(intValue)).comparator().compare(intValue, (Integer) result) > 0) { + result = (T) Integer.valueOf(intValue); + } + } + break; + case LONG: + case TIME: + case TIMESTAMP: + long longValue = + Conversions.fromByteBuffer(Types.LongType.get(), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + if (result == null + || (Literal.of(longValue)).comparator().compare(longValue, (Long) result) < 0) { + result = (T) Long.valueOf(longValue); + } + } else { + if (result == null + || (Literal.of(longValue)).comparator().compare(longValue, (Long) result) > 0) { + result = (T) Long.valueOf(longValue); + } + } + break; + case FLOAT: + float fValue = + Conversions.fromByteBuffer(Types.FloatType.get(), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + if (result == null + || (Literal.of(fValue)).comparator().compare(fValue, (Float) result) < 0) { + result = (T) Float.valueOf(fValue); + } + } else { + if (result == null + || (Literal.of(fValue)).comparator().compare(fValue, (Float) result) > 0) { + result = (T) Float.valueOf(fValue); + } + } + break; + case DOUBLE: + double doubleValue = + Conversions.fromByteBuffer(Types.DoubleType.get(), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + if (result == null + || (Literal.of(doubleValue)).comparator().compare(doubleValue, (Double) result) + < 0) { + result = (T) Double.valueOf(doubleValue); + } + } else { + if (result == null + || (Literal.of(doubleValue)).comparator().compare(doubleValue, (Double) result) + > 0) { + result = (T) 
Double.valueOf(doubleValue); + } + } + break; + case STRING: + String stringValue = + Conversions.fromByteBuffer(Types.StringType.get(), ByteBuffer.wrap(valueInBytes)) + .toString(); + if (isMin) { + if (result == null + || (Literal.of(stringValue)).comparator().compare(stringValue, (String) result) + < 0) { + result = (T) stringValue; + } + } else { + if (result == null + || (Literal.of(stringValue)).comparator().compare(stringValue, (String) result) + > 0) { + result = (T) stringValue; + } + } + + isString = true; + break; + case FIXED: + case BINARY: + ByteBuffer binaryValue = + Conversions.fromByteBuffer(Types.BinaryType.get(), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + if (result == null + || (Literal.of(binaryValue)) + .comparator() + .compare(binaryValue, (ByteBuffer) result) + < 0) { + result = (T) binaryValue; + } + } else { + if (result == null + || (Literal.of(binaryValue)) + .comparator() + .compare(binaryValue, (ByteBuffer) result) + > 0) { + result = (T) binaryValue; + } + } + + isBinary = true; + break; + case DECIMAL: + int precision = ((Types.DecimalType) type).precision(); + int scale = ((Types.DecimalType) type).scale(); + BigDecimal decimal = + Conversions.fromByteBuffer( + Types.DecimalType.of(precision, scale), ByteBuffer.wrap(valueInBytes)); + if (isMin) { + if (result == null + || (Literal.of(decimal)).comparator().compare(decimal, (BigDecimal) result) < 0) { + result = (T) decimal; + } + } else { + if (result == null + || (Literal.of(decimal)).comparator().compare(decimal, (BigDecimal) result) > 0) { + result = (T) decimal; + } + } + + isDecimal = true; + break; + default: + throw new UnsupportedOperationException("Data type is not supported: " + type.typeId()); + } + } + } + + if (isString) { + return (T) org.apache.spark.unsafe.types.UTF8String.fromString(result.toString()); + } + + if (isBinary) { + byte[] arr = new byte[((ByteBuffer) result).remaining()]; + ((ByteBuffer) result).get(arr); + return (T) arr; + } + + if (isDecimal) { 
+ return (T) Decimal.apply(new scala.math.BigDecimal((BigDecimal) result)); + } + + return result; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index 69f8ab972c27..4e1654463af0 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -19,8 +19,10 @@ package org.apache.iceberg.spark.source; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.apache.iceberg.BaseTable; import org.apache.iceberg.BatchScan; import org.apache.iceberg.IncrementalAppendScan; import org.apache.iceberg.IncrementalChangelogScan; @@ -31,7 +33,9 @@ import org.apache.iceberg.TableProperties; import org.apache.iceberg.TableScan; import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.AggregateUtil; import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.BoundAggregate; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.expressions.Expressions; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; @@ -44,9 +48,13 @@ import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsPushDownAggregates; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; import 
org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; import org.apache.spark.sql.connector.read.SupportsReportStatistics; @@ -59,12 +67,15 @@ public class SparkScanBuilder implements ScanBuilder, + SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics { private static final Logger LOG = LoggerFactory.getLogger(SparkScanBuilder.class); private static final Filter[] NO_FILTERS = new Filter[0]; + private StructType pushedAggregateSchema; + private InternalRow[] pushedAggregateRows; private final SparkSession spark; private final Table table; @@ -148,6 +159,115 @@ public Filter[] pushedFilters() { return pushedFilters; } + @Override + public boolean pushAggregation(Aggregation aggregation) { + if (!pushDownAggregate(aggregation)) { + return false; + } + + List boundExpressions = + Lists.newArrayListWithExpectedSize(aggregation.aggregateExpressions().length); + for (AggregateFunc aggregate : aggregation.aggregateExpressions()) { + Expression expr = SparkAggregates.convert(aggregate); + if (expr != null) { + try { + boundExpressions.add( + (BoundAggregate) Binder.bind(schema.asStruct(), expr, caseSensitive)); + } catch (ValidationException e) { + // binding to the table schema failed, so this expression cannot be pushed down + // disable aggregate push down + LOG.info("Failed to convert aggregate expression: {}. {}", aggregate, e.getMessage()); + return false; + } + } else { + // only push down aggregates iff all of them can be pushed down. 
+ LOG.info("Cannot push down aggregate (failed to bind): {}", aggregate); + return false; + } + } + + if (!SparkPushedDownAggregateUtil.metricsModeSupportsAggregatePushDown( + table, boundExpressions)) { + LOG.info("The MetricsMode doesn't support aggregate push down."); + return false; + } + + try { + List aggFields = Lists.newArrayList(); + List aggregateIndexInTableSchema = Lists.newArrayList(); + for (int index = 0; index < boundExpressions.size(); index++) { + // Get the type for each of the pushed down aggregate, and use these Types.NestedField to + // build the schema of this data source scan, which is different from the schema of the + // table. + // e.g. SELECT COUNT(*), MAX(col1), MIN(col1), MAX(col2), MIN(col3) FROM table; + // the schema of the table is + // col1 IntegerType, col2 FloatType, col3 DecimalType + // the schema of the data source scan is + // count(*) LongType, max(col1) IntegerType, max(col2) FloatType, min(col3) DecimalType + BoundAggregate aggregate = boundExpressions.get(index); + Types.NestedField field = AggregateUtil.buildAggregateNestedField(aggregate, index + 1); + if (field.type().isNestedType()) { + // Statistics (upper_bounds and lower_bounds, null_value_counts) are not + // available for top columns, so for top columns, we can only push down Count(*). + // Statistics (upper_bounds and lower_bounds, null_value_counts) are available for + // subfields inside nested columns. Will enable push down Max, Min, Count in + // nested column in next phase. + // TODO: enable push down Count(*) for nested column and Max, Min, Count + // for subfields in nested columns. 
+ LOG.info("Aggregate pushed down is not supported for nested type yet {}", aggregate); + return false; + } + + aggFields.add(field); + aggregateIndexInTableSchema.add( + AggregateUtil.columnIndexInTableSchema(aggregate, table, caseSensitive)); + } + + pushedAggregateSchema = SparkSchemaUtil.convert(new Schema(aggFields)); + this.pushedAggregateRows = + SparkPushedDownAggregateUtil.constructInternalRowForPushedDownAggregate( + spark, table, boundExpressions, aggregateIndexInTableSchema); + } catch (Exception e) { + LOG.info("Aggregate can't be pushed down", e.getMessage()); + return false; + } + + return true; + } + + private boolean pushDownAggregate(Aggregation aggregation) { + if (!(table instanceof BaseTable)) { + return false; + } + + if (!readConf.aggregatePushDown()) { + return false; + } + + Snapshot currentSnapshot = table.currentSnapshot(); + if (currentSnapshot != null) { + Map map = currentSnapshot.summary(); + // if there are row-level deletes in current snapshot, the statistics + // may be changed, so disable push down aggregate. + if (Integer.parseInt(map.getOrDefault("total-position-deletes", "0")) > 0 + || Integer.parseInt(map.getOrDefault("total-equality-deletes", "0")) > 0) { + LOG.info("Cannot push down aggregate (row-level deletes might change the statistics.)"); + return false; + } + } + + // If the group by expression is not the same as the partition, the statistics information + // in metadata files cannot be used to calculate min/max/count. However, if the + // group by expression is the same as the partition, the statistics information can still + // be used to calculate min/max/count, will enable aggregate push down in next phase. 
+ // TODO: enable aggregate push down for partition col group by expression + if (aggregation.groupByExpressions().length > 0) { + LOG.info("Cannot push down aggregate (group by is not supported yet)."); + return false; + } + return true; + } + @Override public void pruneColumns(StructType requestedSchema) { StructType requestedProjection = @@ -183,6 +303,19 @@ private Schema schemaWithMetadataColumns() { @Override public Scan build() { + // if aggregates are pushed down, instead of constructing a SparkBatchQueryScan, creating file + // read tasks and sending over the tasks to Spark executors, a SparkLocalScan will be created + // and the scan is done locally on the Spark driver instead of the executors. The statistics + // info will be retrieved from manifest file and used to build a Spark internal row, which + // contains the pushed down aggregate values. + if (pushedAggregateRows != null) { + return new SparkLocalScan(table, pushedAggregateSchema, pushedAggregateRows); + } else { + return buildBatchScan(); + } + } + + private Scan buildBatchScan() { Long snapshotId = readConf.snapshotId(); Long asOfTimestamp = readConf.asOfTimestamp(); String branch = readConf.branch(); diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java new file mode 100644 index 000000000000..aac972259ebe --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.iceberg.spark.sql;

import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.iceberg.CatalogUtil;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.catalog.Namespace;
import org.apache.iceberg.exceptions.AlreadyExistsException;
import org.apache.iceberg.hive.HiveCatalog;
import org.apache.iceberg.hive.TestHiveMetastore;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkCatalogTestBase;
import org.apache.iceberg.spark.SparkTestBase;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Tests that min/max/count aggregates are pushed down into Iceberg metadata statistics when safe,
 * and are NOT pushed down when the statistics cannot answer the query (unsupported functions or
 * types, missing metrics, row-level deletes, GROUP BY).
 *
 * <p>Push down is detected by checking the EXPLAIN output for the pushed aggregate expressions.
 */
public class TestAggregatePushDown extends SparkCatalogTestBase {

  public TestAggregatePushDown(
      String catalogName, String implementation, Map<String, String> config) {
    super(catalogName, implementation, config);
  }

  @BeforeClass
  public static void startMetastoreAndSpark() {
    SparkTestBase.metastore = new TestHiveMetastore();
    metastore.start();
    SparkTestBase.hiveConf = metastore.hiveConf();

    // aggregate push down is off by default; enable it for the whole suite
    SparkTestBase.spark =
        SparkSession.builder()
            .master("local[2]")
            .config("spark.sql.iceberg.aggregate_pushdown", "true")
            .enableHiveSupport()
            .getOrCreate();

    SparkTestBase.catalog =
        (HiveCatalog)
            CatalogUtil.loadCatalog(
                HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf);

    try {
      catalog.createNamespace(Namespace.of("default"));
    } catch (AlreadyExistsException ignored) {
      // the default namespace already exists. ignore the create error
    }
  }

  @After
  public void removeTables() {
    sql("DROP TABLE IF EXISTS %s", tableName);
  }

  /**
   * Runs EXPLAIN on the given query (with a single %s placeholder for the table name) and returns
   * true iff the plan text contains every one of the given fragments.
   */
  private boolean explainContainsAll(String select, String... fragments) {
    String explainString = sql("EXPLAIN " + select, tableName).get(0)[0].toString();
    return Arrays.stream(fragments).allMatch(explainString::contains);
  }

  /**
   * Runs EXPLAIN on the given query (with a single %s placeholder for the table name) and returns
   * true iff the plan text contains at least one of the given fragments.
   */
  private boolean explainContainsAny(String select, String... fragments) {
    String explainString = sql("EXPLAIN " + select, tableName).get(0)[0].toString();
    return Arrays.stream(fragments).anyMatch(explainString::contains);
  }

  @Test
  public void testDifferentDataTypesAggregatePushDownInPartitionedTable() {
    testDifferentDataTypesAggregatePushDown(true);
  }

  @Test
  public void testDifferentDataTypesAggregatePushDownInNonPartitionedTable() {
    testDifferentDataTypesAggregatePushDown(false);
  }

  private void testDifferentDataTypesAggregatePushDown(boolean hasPartitionCol) {
    String createTable;
    if (hasPartitionCol) {
      createTable =
          "CREATE TABLE %s (id LONG, intData INT, booleanData BOOLEAN, floatData FLOAT, doubleData DOUBLE, "
              + "decimalData DECIMAL(14, 2), binaryData binary) USING iceberg PARTITIONED BY (id)";
    } else {
      createTable =
          "CREATE TABLE %s (id LONG, intData INT, booleanData BOOLEAN, floatData FLOAT, doubleData DOUBLE, "
              + "decimalData DECIMAL(14, 2), binaryData binary) USING iceberg";
    }
    sql(createTable, tableName);
    sql(
        "INSERT INTO TABLE %s VALUES "
            + "(1, null, false, null, null, 11.11, X'1111'),"
            + " (1, null, true, 2.222, 2.222222, 22.22, X'2222'),"
            + " (2, 33, false, 3.333, 3.333333, 33.33, X'3333'),"
            + " (2, 44, true, null, 4.444444, 44.44, X'4444'),"
            + " (3, 55, false, 5.555, 5.555555, 55.55, X'5555'),"
            + " (3, null, true, null, 6.666666, 66.66, null) ",
        tableName);

    String select =
        "SELECT count(*), max(id), min(id), count(id), "
            + "max(intData), min(intData), count(intData), "
            + "max(booleanData), min(booleanData), count(booleanData), "
            + "max(floatData), min(floatData), count(floatData), "
            + "max(doubleData), min(doubleData), count(doubleData), "
            + "max(decimalData), min(decimalData), count(decimalData), "
            + "max(binaryData), min(binaryData), count(binaryData) FROM %s";

    Assert.assertTrue(
        "explain should contain the pushed down aggregates",
        explainContainsAll(
            select,
            "COUNT(*)",
            "MAX(id)",
            "MIN(id)",
            "COUNT(id)",
            "MAX(intData)",
            "MIN(intData)",
            "COUNT(intData)",
            "MAX(booleanData)",
            "MIN(booleanData)",
            "COUNT(booleanData)",
            "MAX(floatData)",
            "MIN(floatData)",
            "COUNT(floatData)",
            "MAX(doubleData)",
            "MIN(doubleData)",
            "COUNT(doubleData)",
            "MAX(decimalData)",
            "MIN(decimalData)",
            "COUNT(decimalData)",
            "MAX(binaryData)",
            "MIN(binaryData)",
            "COUNT(binaryData)"));

    List<Object[]> actual = sql(select, tableName);
    List<Object[]> expected = Lists.newArrayList();
    // COUNT(col) skips nulls, COUNT(*) does not
    expected.add(
        new Object[] {
          6L,
          3L,
          1L,
          6L,
          55,
          33,
          3L,
          true,
          false,
          6L,
          5.555f,
          2.222f,
          3L,
          6.666666,
          2.222222,
          5L,
          new BigDecimal("66.66"),
          new BigDecimal("11.11"),
          6L,
          new byte[] {85, 85},
          new byte[] {17, 17},
          5L
        });
    assertEquals("min/max/count push down", expected, actual);
  }

  @Test
  public void testDateAndTimestampWithPartition() {
    sql(
        "CREATE TABLE %s (id bigint, data string, d date, ts timestamp) USING iceberg PARTITIONED BY (id)",
        tableName);
    sql(
        "INSERT INTO %s VALUES (1, '1', date('2021-11-10'), null),"
            + "(1, '2', date('2021-11-11'), timestamp('2021-11-11 22:22:22')), "
            + "(2, '3', date('2021-11-12'), timestamp('2021-11-12 22:22:22')), "
            + "(2, '4', date('2021-11-13'), timestamp('2021-11-13 22:22:22')), "
            + "(3, '5', null, timestamp('2021-11-14 22:22:22')), "
            + "(3, '6', date('2021-11-14'), null)",
        tableName);
    String select = "SELECT max(d), min(d), count(d), max(ts), min(ts), count(ts) FROM %s";

    Assert.assertTrue(
        "explain should contain the pushed down aggregates",
        explainContainsAll(
            select, "MAX(d)", "MIN(d)", "COUNT(d)", "MAX(ts)", "MIN(ts)", "COUNT(ts)"));

    List<Object[]> actual = sql(select, tableName);
    List<Object[]> expected = Lists.newArrayList();
    expected.add(
        new Object[] {
          Date.valueOf("2021-11-14"),
          Date.valueOf("2021-11-10"),
          5L,
          Timestamp.valueOf("2021-11-14 22:22:22.0"),
          Timestamp.valueOf("2021-11-11 22:22:22.0"),
          4L
        });
    assertEquals("min/max/count push down", expected, actual);
  }

  @Test
  public void testAggregateNotPushDownIfOneCantPushDown() {
    sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
        tableName);
    // SUM is not supported, so the COUNT must not be pushed down either (all or nothing)
    String select = "SELECT COUNT(data), SUM(data) FROM %s";

    Assert.assertFalse(
        "explain should not contain the pushed down aggregates",
        explainContainsAny(select, "COUNT(data)"));
  }

  @Test
  public void testAggregateNotPushDownIfMetaDataMissing() {
    sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666)",
        tableName);
    // keep counts for id, but drop all metrics for data so its min/max cannot be answered
    sql(
        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
        tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts");
    sql(
        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
        tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "id", "counts");
    sql(
        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
        tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "data", "none");

    // SUM is unsupported, so nothing is pushed down
    String select1 = "SELECT COUNT(data), SUM(data) FROM %s";
    Assert.assertFalse(
        "explain should not contain the pushed down aggregates",
        explainContainsAny(select1, "COUNT(data)"));

    // id has count metrics, so COUNT(id) alone can be pushed down
    String select2 = "SELECT COUNT(id) FROM %s";
    Assert.assertTrue(
        "explain should contain the pushed down aggregates",
        explainContainsAll(select2, "COUNT(id)"));

    // id only keeps counts (no bounds), so MAX(id) blocks push down of the whole query
    String select3 = "SELECT COUNT(id), MAX(id) FROM %s";
    Assert.assertFalse(
        "explain should not contain the pushed down aggregates",
        explainContainsAny(select3, "COUNT(id)"));
  }

  @Test
  public void testAggregateNotPushDownForStringType() {
    sql("CREATE TABLE %s (id LONG, data STRING) USING iceberg", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, '1111'), (1, '2222'), (2, '3333'), (2, '4444'), (3, '5555'), (3, '6666') ",
        tableName);

    // SUM is unsupported, so nothing is pushed down
    String select1 = "SELECT COUNT(data), SUM(data) FROM %s";
    Assert.assertFalse(
        "explain should not contain the pushed down aggregates",
        explainContainsAny(select1, "COUNT(data)"));

    // COUNT on a string column only needs counts, which are present by default
    String select2 = "SELECT COUNT(data) FROM %s";
    Assert.assertTrue(
        "explain should contain the pushed down aggregates",
        explainContainsAll(select2, "COUNT(data)"));

    List<Object[]> actual = sql(select2, tableName);
    List<Object[]> expected = Lists.newArrayList();
    expected.add(new Object[] {6L});
    assertEquals("min/max/count push down", expected, actual);

    // with full metrics, string bounds are recorded and MAX(data) can also be pushed down
    sql(
        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
        tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "full");
    String select3 = "SELECT COUNT(data), MAX(data) FROM %s";
    Assert.assertTrue(
        "explain should contain the pushed down aggregates",
        explainContainsAll(select3, "COUNT(data)"));

    actual = sql(select3, tableName);
    expected = Lists.newArrayList();
    expected.add(new Object[] {6L, "6666"});
    assertEquals("min/max/count push down", expected, actual);
  }

  @Test
  public void testAggregateWithComplexTypeNotPushDown() {
    sql("CREATE TABLE %s (id INT, complex STRUCT<c1 INT, c2 STRING>) USING iceberg", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\")),"
            + "(2, named_struct(\"c1\", 2, \"c2\", \"v2\"))",
        tableName);
    String select = "SELECT max(complex), min(complex), count(complex) FROM %s";

    Assert.assertFalse(
        "min/max/count not pushed down for complex types",
        explainContainsAny(select, "MAX(complex)", "MIN(complex)", "COUNT(complex)"));
  }

  @Test
  public void testAggregatePushDownInDeleteCopyOnWrite() {
    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
        tableName);
    // copy-on-write delete rewrites the data files, so the statistics stay accurate
    sql("DELETE FROM %s WHERE data = 1111", tableName);
    String select = "SELECT max(data), min(data), count(data) FROM %s";

    Assert.assertTrue(
        "min/max/count pushed down for deleted",
        explainContainsAll(select, "MAX(data)", "MIN(data)", "COUNT(data)"));

    List<Object[]> actual = sql(select, tableName);
    List<Object[]> expected = Lists.newArrayList();
    expected.add(new Object[] {6666, 2222, 5L});
    assertEquals("min/max/count push down", expected, actual);
  }

  @Ignore
  public void testAggregatePushDownInDeleteMergeOnRead() {
    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (1, 3333), (2, 4444), (2, 5555), (2, 6666) ",
        tableName);
    sql(
        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
        tableName, TableProperties.FORMAT_VERSION, "2");
    sql(
        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
        tableName, TableProperties.DELETE_MODE, "merge-on-read");
    // merge-on-read leaves delete files behind, which disables aggregate push down
    sql("DELETE FROM %s WHERE data = 1111", tableName);
    String select = "SELECT max(data), min(data), count(data) FROM %s";

    Assert.assertFalse(
        "min/max/count not pushed down for deleted",
        explainContainsAll(select, "MAX(data)", "MIN(data)", "COUNT(data)"));
  }

  @Test
  public void testAggregateWithGroupByNotPushDown() {
    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg PARTITIONED BY (id)", tableName);
    sql(
        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
        tableName);
    String select = "SELECT max(data), min(data) FROM %s GROUP BY id";

    Assert.assertFalse(
        "min/max not pushed down", explainContainsAny(select, "MAX(data)", "MIN(data)"));
  }
}