diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index c01306ccf5b9..ee94c2b2fdbe 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -27,13 +27,13 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils
 import org.apache.spark.sql.catalyst.expressions.IsNotNull
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.LeftAnti
@@ -390,7 +390,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
   }
 
   private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = {
-    ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
+    V2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
   }
 
   private def buildMergeDeltaProjections(
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
index b460f648d28b..abadab4e5347 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ProjectingInternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils
+import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
 import org.apache.spark.sql.connector.write.RowLevelOperation
@@ -73,7 +73,7 @@ trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
     operation match {
       case supportsDelta: SupportsDelta =>
-        val rowIdAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+        val rowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference](
           supportsDelta.rowId.toSeq,
           relation)
 
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala
deleted file mode 100644
index 16ff67a70522..000000000000
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.connector.expressions.{Expression => V2Expression}
-import org.apache.spark.sql.connector.expressions.{SortDirection => V2SortDirection}
-import org.apache.spark.sql.connector.expressions.{NullOrdering => V2NullOrdering}
-import org.apache.spark.sql.connector.expressions.BucketTransform
-import org.apache.spark.sql.connector.expressions.DaysTransform
-import org.apache.spark.sql.connector.expressions.FieldReference
-import org.apache.spark.sql.connector.expressions.HoursTransform
-import org.apache.spark.sql.connector.expressions.IdentityTransform
-import org.apache.spark.sql.connector.expressions.MonthsTransform
-import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.expressions.SortValue
-import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.connector.expressions.TruncateTransform
-import org.apache.spark.sql.connector.expressions.YearsTransform
-import org.apache.spark.sql.errors.QueryCompilationErrors
-
-/**
- * A class that is inspired by V2ExpressionUtils in Spark but supports Iceberg transforms.
- */
-object ExtendedV2ExpressionUtils extends SQLConfHelper {
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
-
-  def resolveRef[T <: NamedExpression](ref: NamedReference, plan: LogicalPlan): T = {
-    plan.resolve(ref.fieldNames.toSeq, conf.resolver) match {
-      case Some(namedExpr) =>
-        namedExpr.asInstanceOf[T]
-      case None =>
-        val name = ref.fieldNames.toSeq.quoted
-        val outputString = plan.output.map(_.name).mkString(",")
-        throw QueryCompilationErrors.cannotResolveAttributeError(name, outputString)
-    }
-  }
-
-  def resolveRefs[T <: NamedExpression](refs: Seq[NamedReference], plan: LogicalPlan): Seq[T] = {
-    refs.map(ref => resolveRef[T](ref, plan))
-  }
-
-  def toCatalyst(expr: V2Expression, query: LogicalPlan): Expression = {
-    expr match {
-      case SortValue(child, direction, nullOrdering) =>
-        val catalystChild = toCatalyst(child, query)
-        SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty)
-      case IdentityTransform(ref) =>
-        resolveRef[NamedExpression](ref, query)
-      case t: Transform if BucketTransform.unapply(t).isDefined =>
-        t match {
-          // sort columns will be empty for bucket.
-          case BucketTransform(numBuckets, cols, _) =>
-            IcebergBucketTransform(numBuckets, resolveRef[NamedExpression](cols.head, query))
-          case _ => t.asInstanceOf[Expression]
-          // do nothing
-        }
-      case TruncateTransform(length, ref) =>
-        IcebergTruncateTransform(resolveRef[NamedExpression](ref, query), length)
-      case YearsTransform(ref) =>
-        IcebergYearTransform(resolveRef[NamedExpression](ref, query))
-      case MonthsTransform(ref) =>
-        IcebergMonthTransform(resolveRef[NamedExpression](ref, query))
-      case DaysTransform(ref) =>
-        IcebergDayTransform(resolveRef[NamedExpression](ref, query))
-      case HoursTransform(ref) =>
-        IcebergHourTransform(resolveRef[NamedExpression](ref, query))
-      case ref: FieldReference =>
-        resolveRef[NamedExpression](ref, query)
-      case _ =>
-        throw new AnalysisException(s"$expr is not currently supported")
-    }
-  }
-
-  private def toCatalyst(direction: V2SortDirection): SortDirection = direction match {
-    case V2SortDirection.ASCENDING => Ascending
-    case V2SortDirection.DESCENDING => Descending
-  }
-
-  private def toCatalyst(nullOrdering: V2NullOrdering): NullOrdering = nullOrdering match {
-    case V2NullOrdering.NULLS_FIRST => NullsFirst
-    case V2NullOrdering.NULLS_LAST => NullsLast
-  }
-}
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala
index 10db698b9b91..8495856fb6b0 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.analysis.NamedRelation
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN
 import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
@@ -80,7 +80,7 @@ case class WriteIcebergDelta(
   }
 
   private def rowIdAttrsResolved: Boolean = {
-    val rowIdAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+    val rowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference](
       operation.rowId.toSeq,
       originalTable)
 
@@ -92,7 +92,7 @@ case class WriteIcebergDelta(
   private def metadataAttrsResolved: Boolean = {
     projections.metadataProjection match {
       case Some(projection) =>
-        val metadataAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+        val metadataAttrs = V2ExpressionUtils.resolveRefs[AttributeReference](
          operation.requiredMetadataAttributes.toSeq,
          originalTable)
 
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala
deleted file mode 100644
index 8c37b1b75924..000000000000
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils.toCatalyst
-import org.apache.spark.sql.catalyst.expressions.SortOrder
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression
-import org.apache.spark.sql.catalyst.plans.logical.Sort
-import org.apache.spark.sql.connector.distributions.ClusteredDistribution
-import org.apache.spark.sql.connector.distributions.OrderedDistribution
-import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution
-import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering
-import org.apache.spark.sql.connector.write.Write
-import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
-import scala.collection.compat.immutable.ArraySeq
-
-/**
- * A rule that is inspired by DistributionAndOrderingUtils in Spark but supports Iceberg transforms.
- *
- * Note that similarly to the original rule in Spark, it does not let AQE pick the number of shuffle
- * partitions. See SPARK-34230 for context.
- */
-object ExtendedDistributionAndOrderingUtils {
-
-  def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match {
-    case write: RequiresDistributionAndOrdering =>
-      val numPartitions = write.requiredNumPartitions()
-      val distribution = write.requiredDistribution match {
-        case d: OrderedDistribution => d.ordering.map(e => toCatalyst(e, query))
-        case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query))
-        case _: UnspecifiedDistribution => Array.empty[Expression]
-      }
-
-      val queryWithDistribution = if (distribution.nonEmpty) {
-        val finalNumPartitions = if (numPartitions > 0) {
-          numPartitions
-        } else {
-          conf.numShufflePartitions
-        }
-        // the conversion to catalyst expressions above produces SortOrder expressions
-        // for OrderedDistribution and generic expressions for ClusteredDistribution
-        // this allows RepartitionByExpression to pick either range or hash partitioning
-        RepartitionByExpression(ArraySeq.unsafeWrapArray(distribution), query, finalNumPartitions)
-      } else if (numPartitions > 0) {
-        throw QueryCompilationErrors.numberOfPartitionsNotAllowedWithUnspecifiedDistributionError()
-      } else {
-        query
-      }
-
-      val ordering = write.requiredOrdering.toSeq
-        .map(e => toCatalyst(e, query))
-        .asInstanceOf[Seq[SortOrder]]
-
-      val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
-        Sort(ordering, global = false, queryWithDistribution)
-      } else {
-        queryWithDistribution
-      }
-
-      queryWithDistributionAndOrdering
-
-    case _ =>
-      query
-  }
-}
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala
index 83b793925db2..0d13f6a5230b 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala
@@ -22,100 +22,40 @@ package org.apache.spark.sql.execution.datasources.v2
 import java.util.Optional
 import java.util.UUID
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
-import org.apache.spark.sql.catalyst.plans.logical.AppendData
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.logical.OverwriteByExpression
-import org.apache.spark.sql.catalyst.plans.logical.OverwritePartitionsDynamic
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData
 import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
-import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.write.DeltaWriteBuilder
 import org.apache.spark.sql.connector.write.LogicalWriteInfoImpl
-import org.apache.spark.sql.connector.write.SupportsDynamicOverwrite
-import org.apache.spark.sql.connector.write.SupportsOverwrite
-import org.apache.spark.sql.connector.write.SupportsTruncate
 import org.apache.spark.sql.connector.write.WriteBuilder
-import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.sources.AlwaysTrue
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 
 /**
- * A rule that is inspired by V2Writes in Spark but supports Iceberg transforms.
+ * A rule that is inspired by V2Writes in Spark but supports Iceberg specific plans.
  */
 object ExtendedV2Writes extends Rule[LogicalPlan] with PredicateHelper {
 
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) if isIcebergRelation(r) =>
-      val writeBuilder = newWriteBuilder(r.table, query.schema, options)
-      val write = writeBuilder.build()
-      val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf)
-      a.copy(write = Some(write), query = newQuery)
-
-    case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None, _)
-        if isIcebergRelation(r) =>
-      // fail if any filter cannot be converted. correctness depends on removing all matching data.
-      val filters = splitConjunctivePredicates(deleteExpr).flatMap { pred =>
-        val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true)
-        if (filter.isEmpty) {
-          throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred)
-        }
-        filter
-      }.toArray
-
-      val table = r.table
-      val writeBuilder = newWriteBuilder(table, query.schema, options)
-      val write = writeBuilder match {
-        case builder: SupportsTruncate if isTruncate(filters) =>
-          builder.truncate().build()
-        case builder: SupportsOverwrite =>
-          builder.overwrite(filters).build()
-        case _ =>
-          throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
-      }
-
-      val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf)
-      o.copy(write = Some(write), query = newQuery)
-
-    case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None)
-        if isIcebergRelation(r) =>
-      val table = r.table
-      val writeBuilder = newWriteBuilder(table, query.schema, options)
-      val write = writeBuilder match {
-        case builder: SupportsDynamicOverwrite =>
-          builder.overwriteDynamicPartitions().build()
-        case _ =>
-          throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table)
-      }
-      val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf)
-      o.copy(write = Some(write), query = newQuery)
-
     case rd @ ReplaceIcebergData(r: DataSourceV2Relation, query, _, None) =>
       val rowSchema = StructType.fromAttributes(rd.dataInput)
       val writeBuilder = newWriteBuilder(r.table, rowSchema, Map.empty)
       val write = writeBuilder.build()
-      val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
 
     case wd @ WriteIcebergDelta(r: DataSourceV2Relation, query, _, projections, None) =>
       val deltaWriteBuilder = newDeltaWriteBuilder(r.table, Map.empty, projections)
       val deltaWrite = deltaWriteBuilder.build()
-      val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(deltaWrite, query, conf)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(deltaWrite, query, r.funCatalog)
       wd.copy(write = Some(deltaWrite), query = newQuery)
   }
 
-  private def isTruncate(filters: Array[Filter]): Boolean = {
-    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
-  }
-
   private def newWriteBuilder(
       table: Table,
       rowSchema: StructType,
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala
index f5d5affe9e92..de26ea4486cd 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala
@@ -26,10 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.AttributeMap
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
 import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable
@@ -80,8 +80,8 @@ case class RowLevelCommandDynamicPruning(spark: SparkSession) extends Rule[Logic
         val matchingRowsPlan = buildMatchingRowsPlan(relation, command)
 
         val filterAttrs = ArraySeq.unsafeWrapArray(scan.filterAttributes)
-        val buildKeys = ExtendedV2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
-        val pruningKeys = ExtendedV2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
+        val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
+        val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
         val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys)
 
         Filter(dynamicPruningCond, r)
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java
deleted file mode 100644
index 8d2e10ea17eb..000000000000
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.iceberg.spark.extensions;
-
-import java.math.BigDecimal;
-import java.util.Map;
-import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
-import org.apache.spark.sql.Column;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform;
-import org.junit.After;
-import org.junit.Test;
-
-public class TestIcebergExpressions extends SparkExtensionsTestBase {
-
-  public TestIcebergExpressions(
-      String catalogName, String implementation, Map<String, String> config) {
-    super(catalogName, implementation, config);
-  }
-
-  @After
-  public void removeTables() {
-    sql("DROP TABLE IF EXISTS %s", tableName);
-    sql("DROP VIEW IF EXISTS emp");
-    sql("DROP VIEW IF EXISTS v");
-  }
-
-  @Test
-  public void testTruncateExpressions() {
-    sql(
-        "CREATE TABLE %s ( "
-            + " int_c INT, long_c LONG, dec_c DECIMAL(4, 2), str_c STRING, binary_c BINARY "
-            + ") USING iceberg",
-        tableName);
-
-    sql(
-        "CREATE TEMPORARY VIEW emp "
-            + "AS SELECT * FROM VALUES (101, 10001, 10.65, '101-Employee', CAST('1234' AS BINARY)) "
-            + "AS EMP(int_c, long_c, dec_c, str_c, binary_c)");
-
-    sql("INSERT INTO %s SELECT * FROM emp", tableName);
-
-    Dataset<Row> df = spark.sql("SELECT * FROM " + tableName);
-    df.select(
-            new Column(new IcebergTruncateTransform(df.col("int_c").expr(), 2)).as("int_c"),
-            new Column(new IcebergTruncateTransform(df.col("long_c").expr(), 2)).as("long_c"),
-            new Column(new IcebergTruncateTransform(df.col("dec_c").expr(), 50)).as("dec_c"),
-            new Column(new IcebergTruncateTransform(df.col("str_c").expr(), 2)).as("str_c"),
-            new Column(new IcebergTruncateTransform(df.col("binary_c").expr(), 2)).as("binary_c"))
-        .createOrReplaceTempView("v");
-
-    assertEquals(
-        "Should have expected rows",
-        ImmutableList.of(row(100, 10000L, new BigDecimal("10.50"), "10", "12")),
-        sql("SELECT int_c, long_c, dec_c, str_c, CAST(binary_c AS STRING) FROM v"));
-  }
-}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
index 44aca898b696..6cda93f8674e 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
@@ -41,9 +41,11 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.apache.spark.sql.internal.SQLConf;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Assume;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 public class TestRewriteDataFilesProcedure extends SparkExtensionsTestBase {
@@ -55,6 +57,12 @@ public TestRewriteDataFilesProcedure(
     super(catalogName, implementation, config);
   }
 
+  @BeforeClass
+  public static void setupSpark() {
+    // disable AQE as tests assume that writes generate a particular number of files
+    spark.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false");
+  }
+
   @After
   public void removeTable() {
     sql("DROP TABLE IF EXISTS %s", tableName);
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java
index 2e5e383baf42..38f15a42958c 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java
@@ -18,18 +18,13 @@
  */
 package org.apache.iceberg.spark;
 
-import org.apache.iceberg.spark.functions.SparkFunctions;
 import org.apache.iceberg.spark.procedures.SparkProcedures;
 import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
 import org.apache.iceberg.spark.source.HasIcebergCatalog;
-import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
-import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
-import org.apache.spark.sql.connector.catalog.FunctionCatalog;
 import org.apache.spark.sql.connector.catalog.Identifier;
 import org.apache.spark.sql.connector.catalog.StagingTableCatalog;
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces;
-import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
 import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;
 
@@ -38,7 +33,7 @@ abstract class BaseCatalog
         ProcedureCatalog,
         SupportsNamespaces,
         HasIcebergCatalog,
-        FunctionCatalog {
+        SupportsFunctions {
 
   @Override
   public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException {
@@ -58,35 +53,17 @@ public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException
   }
 
   @Override
-  public Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException {
-    if (namespace.length == 0 || isSystemNamespace(namespace)) {
-      return SparkFunctions.list().stream()
-          .map(name -> Identifier.of(namespace, name))
-          .toArray(Identifier[]::new);
-    } else if (namespaceExists(namespace)) {
-      return new Identifier[0];
-    }
-
-    throw new NoSuchNamespaceException(namespace);
-  }
-
-  @Override
-  public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
-    String[] namespace = ident.namespace();
-    String name = ident.name();
-
+  public boolean isFunctionNamespace(String[] namespace) {
     // Allow for empty namespace, as Spark's storage partitioned joins look up
    // the corresponding functions to generate transforms for partitioning
    // with an empty namespace, such as `bucket`.
    // Otherwise, use `system` namespace.
-    if (namespace.length == 0 || isSystemNamespace(namespace)) {
-      UnboundFunction func = SparkFunctions.load(name);
-      if (func != null) {
-        return func;
-      }
-    }
+    return namespace.length == 0 || isSystemNamespace(namespace);
+  }
 
-    throw new NoSuchFunctionException(ident);
+  @Override
+  public boolean isExistingNamespace(String[] namespace) {
+    return namespaceExists(namespace);
   }
 
   private static boolean isSystemNamespace(String[] namespace) {
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java
index 52d68db2e4f9..781f61b33f0e 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java
@@ -53,7 +53,7 @@ public SortOrder truncate(
       String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) {
     return Expressions.sort(
         Expressions.apply(
-            "truncate", Expressions.column(quotedName(id)), Expressions.literal(width)),
+            "truncate", Expressions.literal(width), Expressions.column(quotedName(id))),
         toSpark(direction),
         toSpark(nullOrder));
   }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java
index 2533b3bd75b5..21317526d2aa 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java
@@ -43,7 +43,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /** An internal table catalog that is capable of loading tables from a cache. */
-public class SparkCachedTableCatalog implements TableCatalog {
+public class SparkCachedTableCatalog implements TableCatalog, SupportsFunctions {
 
   private static final String CLASS_NAME = SparkCachedTableCatalog.class.getName();
   private static final Splitter COMMA = Splitter.on(",");
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFunctionCatalog.java
similarity index 55%
rename from spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala
rename to spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFunctionCatalog.java
index 2a3269e2db1d..2183b9e5df4d 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFunctionCatalog.java
@@ -16,23 +16,30 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+package org.apache.iceberg.spark;
 
-package org.apache.spark.sql.connector.expressions
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
-import org.apache.spark.sql.types.IntegerType
+/**
+ * A function catalog that can be used to resolve Iceberg functions without a metastore connection.
+ */
+public class SparkFunctionCatalog implements SupportsFunctions {
+
+  private static final SparkFunctionCatalog INSTANCE = new SparkFunctionCatalog();
+
+  private String name = "iceberg-function-catalog";
+
+  public static SparkFunctionCatalog get() {
+    return INSTANCE;
+  }
+
+  @Override
+  public void initialize(String catalogName, CaseInsensitiveStringMap options) {
+    this.name = catalogName;
+  }
 
-private[sql] object TruncateTransform {
-  def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match {
-    case transform: Transform =>
-      transform match {
-        case NamedTransform("truncate", Seq(Ref(seq: Seq[String]), Lit(value: Int, IntegerType))) =>
-          Some((value, FieldReference(seq)))
-        case NamedTransform("truncate", Seq(Lit(value: Int, IntegerType), Ref(seq: Seq[String]))) =>
-          Some((value, FieldReference(seq)))
-        case _ =>
-          None
-      }
-    case _ =>
-      None
+  @Override
+  public String name() {
+    return name;
   }
 }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java
new file mode 100644
index 000000000000..34897d2b4c01
--- /dev/null
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import org.apache.iceberg.spark.functions.SparkFunctions;
+import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.connector.catalog.FunctionCatalog;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+
+interface SupportsFunctions extends FunctionCatalog {
+
+  default boolean isFunctionNamespace(String[] namespace) {
+    return namespace.length == 0;
+  }
+
+  default boolean isExistingNamespace(String[] namespace) {
+    return namespace.length == 0;
+  }
+
+  default Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException {
+    if (isFunctionNamespace(namespace)) {
+      return SparkFunctions.list().stream()
+          .map(name -> Identifier.of(namespace, name))
+          .toArray(Identifier[]::new);
+    } else if (isExistingNamespace(namespace)) {
+      return new Identifier[0];
+    }
+
+    throw new NoSuchNamespaceException(namespace);
+  }
+
+  default UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
+    String[] namespace = ident.namespace();
+    String name = ident.name();
+
+    if (isFunctionNamespace(namespace)) {
+      UnboundFunction func = SparkFunctions.load(name);
+      if (func != null) {
+        return func;
+      }
+    }
+
+    throw new NoSuchFunctionException(ident);
+  }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java
index 1add6383c618..53d5f49b9f73 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java
@@ -21,11 +21,13 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 import org.apache.iceberg.FileScanTask;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
 import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil;
+import org.apache.iceberg.spark.SparkFunctionCatalog;
 import org.apache.iceberg.spark.SparkReadOptions;
 import org.apache.iceberg.spark.SparkWriteOptions;
 import org.apache.iceberg.util.PropertyUtil;
@@ -34,11 +36,13 @@ import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
-import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils$;
+import org.apache.spark.sql.connector.distributions.Distribution;
 import org.apache.spark.sql.connector.distributions.Distributions;
 import org.apache.spark.sql.connector.distributions.OrderedDistribution;
 import org.apache.spark.sql.connector.expressions.SortOrder;
-import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering;
+import org.apache.spark.sql.execution.datasources.v2.DistributionAndOrderingUtils$;
+import scala.Option;
 
 abstract class SparkShufflingDataRewriter extends SparkSizeBasedDataRewriter {
 
@@ -61,7 +65,10 @@ protected SparkShufflingDataRewriter(SparkSession spark, Table table) {
     super(spark, table);
   }
 
-  protected abstract Dataset<Row> sortedDF(Dataset<Row> df, List<FileScanTask> group);
+  protected abstract org.apache.iceberg.SortOrder sortOrder();
+
+  protected abstract Dataset<Row> sortedDF(
+      Dataset<Row> df, Function<Dataset<Row>, Dataset<Row>> sortFunc);
 
   @Override
   public Set<String> validOptions() {
@@ -79,9 +86,6 @@ public void init(Map<String, String> options) {
 
   @Override
   public void doRewrite(String groupId, List<FileScanTask> group) {
-    // the number of shuffle partition controls the number of output files
-    spark().conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), numShufflePartitions(group));
-
     Dataset<Row> scanDF =
         spark()
             .read()
@@ -89,7 +93,7 @@ public void doRewrite(String groupId, List<FileScanTask> group) {
             .option(SparkReadOptions.SCAN_TASK_SET_ID, groupId)
             .load(groupId);
 
-    Dataset<Row> sortedDF = sortedDF(scanDF, group);
+    Dataset<Row> sortedDF = sortedDF(scanDF, sortFunction(group));
 
     sortedDF
         .write()
@@ -101,30 +105,35 @@ public void doRewrite(String groupId, List<FileScanTask> group) {
         .save(groupId);
   }
 
-  protected Dataset<Row> sort(Dataset<Row> df, org.apache.iceberg.SortOrder sortOrder) {
-    SortOrder[] ordering = SparkDistributionAndOrderingUtil.convert(sortOrder);
-    OrderedDistribution distribution = Distributions.ordered(ordering);
-    SQLConf conf = spark().sessionState().conf();
-    LogicalPlan plan = df.logicalPlan();
-    LogicalPlan sortPlan =
-        DistributionAndOrderingUtils$.MODULE$.prepareQuery(distribution, ordering, plan, conf);
-    return new Dataset<>(spark(), sortPlan, df.encoder());
+  private Function<Dataset<Row>, Dataset<Row>> sortFunction(List<FileScanTask> group) {
+    SortOrder[] ordering = SparkDistributionAndOrderingUtil.convert(outputSortOrder(group));
+    int numShufflePartitions = numShufflePartitions(group);
+    return (df) -> transformPlan(df, plan -> sortPlan(plan, ordering, numShufflePartitions));
   }
 
-  protected org.apache.iceberg.SortOrder outputSortOrder(
-      List<FileScanTask> group, org.apache.iceberg.SortOrder sortOrder) {
+  private LogicalPlan sortPlan(LogicalPlan plan, SortOrder[] ordering, int numShufflePartitions) {
+    SparkFunctionCatalog catalog = SparkFunctionCatalog.get();
+    OrderedWrite write = new OrderedWrite(ordering, numShufflePartitions);
+    return DistributionAndOrderingUtils$.MODULE$.prepareQuery(write, plan, Option.apply(catalog));
+  }
+
+  private Dataset<Row> transformPlan(Dataset<Row> df, Function<LogicalPlan, LogicalPlan> func) {
+    return new Dataset<>(spark(), func.apply(df.logicalPlan()), df.encoder());
+  }
+
+  private org.apache.iceberg.SortOrder outputSortOrder(List<FileScanTask> group) {
     boolean includePartitionColumns = !group.get(0).spec().equals(table().spec());
     if (includePartitionColumns) {
       // build in the requirement for partition sorting into our sort order
       // as the original spec for this group does not match the output spec
-      return SortOrderUtil.buildSortOrder(table(), sortOrder);
+      return SortOrderUtil.buildSortOrder(table(), sortOrder());
     } else {
-      return sortOrder;
+      return sortOrder();
     }
   }
 
-  private long numShufflePartitions(List<FileScanTask> group) {
-    long numOutputFiles = numOutputFiles((long) (inputSize(group) * compressionFactor));
+  private int numShufflePartitions(List<FileScanTask> group) {
+    int numOutputFiles = (int) numOutputFiles((long) (inputSize(group) * compressionFactor));
     return Math.max(1, numOutputFiles);
   }
 
@@ -135,4 +144,36 @@ private double compressionFactor(Map<String, String> options) {
         value > 0, "'%s' is set to %s but must be > 0", COMPRESSION_FACTOR, value);
     return value;
   }
+
+  private static class OrderedWrite implements RequiresDistributionAndOrdering {
+    private final OrderedDistribution distribution;
+    private final SortOrder[] ordering;
+    private final int numShufflePartitions;
+
+    OrderedWrite(SortOrder[] ordering, int numShufflePartitions) {
+      this.distribution = Distributions.ordered(ordering);
+      this.ordering = ordering;
+      this.numShufflePartitions = numShufflePartitions;
+    }
+
+    @Override
+    public Distribution requiredDistribution() {
+      return distribution;
+    }
+
+    @Override
+    public boolean distributionStrictlyRequired() {
+      return true;
+    }
+
+    @Override
+    public int requiredNumPartitions() {
+      return numShufflePartitions;
+    }
+
+    @Override
+    public SortOrder[] requiredOrdering() {
+      return ordering;
+    }
+  }
 }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java
index 4615f3cebc92..1f70d4d7ca9d 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java
@@ -18,8 +18,7 @@
  */
 package org.apache.iceberg.spark.actions;
 
-import java.util.List;
-import org.apache.iceberg.FileScanTask;
+import java.util.function.Function;
 import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
@@ -54,7 +53,12 @@ public String description() {
   }
 
   @Override
-  protected Dataset<Row> sortedDF(Dataset<Row> df, List<FileScanTask> group) {
-    return sort(df, outputSortOrder(group, sortOrder));
+  protected SortOrder sortOrder() {
+    return sortOrder;
+  }
+
+  @Override
+  protected Dataset<Row> sortedDF(Dataset<Row> df, Function<Dataset<Row>, Dataset<Row>> sortFunc) {
+    return sortFunc.apply(df);
   }
 }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java
index 68db76d37fcb..91eaa91f6889 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java
@@ -23,7 +23,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import org.apache.iceberg.FileScanTask;
+import java.util.function.Function;
 import org.apache.iceberg.NullOrder;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.SortDirection;
@@ -104,9 +104,14 @@ public void init(Map<String, String> options) {
   }
 
   @Override
-  protected Dataset<Row> sortedDF(Dataset<Row> df, List<FileScanTask> group) {
+  protected SortOrder sortOrder() {
+    return Z_SORT_ORDER;
+  }
+
+  @Override
+  protected Dataset<Row> sortedDF(Dataset<Row> df, Function<Dataset<Row>, Dataset<Row>> sortFunc) {
     Dataset<Row> zValueDF = df.withColumn(Z_COLUMN, zValue(df));
-    Dataset<Row> sortedDF = sort(zValueDF, outputSortOrder(group, Z_SORT_ORDER));
+    Dataset<Row> sortedDF = sortFunc.apply(zValueDF);
     return sortedDF.drop(Z_COLUMN);
   }
 
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
index 133ca45b4603..30f04659dfc9 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
@@ -27,7 +27,6 @@ import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
-import org.apache.iceberg.spark.Spark3Util;
 import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil;
 import org.apache.iceberg.spark.SparkFilters;
 import org.apache.iceberg.spark.SparkSchemaUtil;
@@ -146,15 +145,8 @@ public Write build() {
     SortOrder[] ordering;
 
     if (useTableDistributionAndOrdering) {
-      if (Spark3Util.extensionsEnabled(spark) || allIdentityTransforms(table.spec())) {
-        distribution = buildRequiredDistribution();
-        ordering = buildRequiredOrdering(distribution);
-      } else {
-        LOG.warn(
-            "Skipping distribution/ordering: extensions are disabled and spec contains unsupported transforms");
-        distribution = Distributions.unspecified();
-        ordering = NO_ORDERING;
-      }
+      distribution = buildRequiredDistribution();
+      ordering = buildRequiredOrdering(distribution);
     } else {
       LOG.info("Skipping distribution/ordering: disabled per job configuration");
       distribution = Distributions.unspecified();
diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala
deleted file mode 100644
index dffac82af791..000000000000
--- a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import java.nio.ByteBuffer
-import java.nio.CharBuffer
-import java.nio.charset.StandardCharsets
-import java.util.function
-import org.apache.iceberg.spark.SparkSchemaUtil
-import org.apache.iceberg.transforms.Transform
-import org.apache.iceberg.transforms.Transforms
-import org.apache.iceberg.types.Type
-import org.apache.iceberg.types.Types
-import org.apache.iceberg.util.ByteBuffers
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.AbstractDataType
-import org.apache.spark.sql.types.BinaryType
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.types.Decimal
-import org.apache.spark.sql.types.DecimalType
-import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.sql.types.StringType
-import org.apache.spark.sql.types.TimestampType
-import org.apache.spark.unsafe.types.UTF8String
-
-abstract class IcebergTransformExpression
-  extends UnaryExpression with CodegenFallback with NullIntolerant {
-
-  @transient lazy val icebergInputType: Type = SparkSchemaUtil.convert(child.dataType)
-}
-
-abstract class IcebergTimeTransform
-  extends IcebergTransformExpression with ImplicitCastInputTypes {
-
-  def transform: function.Function[Any, Integer]
-
-  override protected def nullSafeEval(value: Any): Any = {
-    transform(value).toInt
-  }
-
-  override def dataType: DataType = IntegerType
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
-}
-
-case class IcebergYearTransform(child: Expression)
-  extends IcebergTimeTransform {
-
-  @transient lazy val transform: function.Function[Any, Integer] = Transforms.year[Any]().bind(icebergInputType)
-
-  override protected def withNewChildInternal(newChild: Expression): Expression = {
-    copy(child = newChild)
-  }
-}
-
-case class IcebergMonthTransform(child: Expression)
-  extends IcebergTimeTransform {
-
-  @transient lazy val transform: function.Function[Any, Integer] = Transforms.month[Any]().bind(icebergInputType)
-
-  override protected def withNewChildInternal(newChild: Expression): Expression = {
-    copy(child = newChild)
-  }
-}
-
-case class IcebergDayTransform(child: Expression)
-  extends IcebergTimeTransform {
-
-  @transient lazy val transform: function.Function[Any, Integer] = Transforms.day[Any]().bind(icebergInputType)
-
-  override protected def withNewChildInternal(newChild: Expression): Expression = {
-    copy(child = newChild)
-  }
-}
-
-case class IcebergHourTransform(child: Expression)
-  extends IcebergTimeTransform {
-
-  @transient lazy val transform: function.Function[Any, Integer] = Transforms.hour[Any]().bind(icebergInputType)
-
-  override protected def withNewChildInternal(newChild: Expression): Expression = {
-    copy(child = newChild)
-  }
-}
-
-case class IcebergBucketTransform(numBuckets: Int, child: Expression) extends IcebergTransformExpression {
-
-  @transient lazy val bucketFunc: Any => Int = child.dataType match {
-    case _: DecimalType =>
-      val t = Transforms.bucket[Any](numBuckets).bind(icebergInputType)
-      d: Any => t(d.asInstanceOf[Decimal].toJavaBigDecimal).toInt
-    case _: StringType =>
-      // the spec requires that the hash of a string is equal to the hash of its UTF-8 encoded bytes
-      // TODO: pass bytes without the copy out of the InternalRow
-      val t = Transforms.bucket[ByteBuffer](numBuckets).bind(Types.BinaryType.get())
-      s: Any => t(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes)).toInt
-    case _ =>
-      val t = Transforms.bucket[Any](numBuckets).bind(icebergInputType)
-      a: Any => t(a).toInt
-  }
-
-  override protected def nullSafeEval(value: Any): Any = {
-    bucketFunc(value)
-  }
-
-  override def dataType: DataType = IntegerType
-
-  override protected def withNewChildInternal(newChild: Expression): Expression = {
-    copy(child = newChild)
-  }
-}
-
-case class IcebergTruncateTransform(child: Expression, width: Int) extends IcebergTransformExpression {
-
-  @transient lazy val truncateFunc: Any => Any = child.dataType match {
-    case _: DecimalType =>
-      val t = Transforms.truncate[java.math.BigDecimal](width).bind(icebergInputType)
-      d: Any => Decimal.apply(t(d.asInstanceOf[Decimal].toJavaBigDecimal))
-    case _: StringType =>
-      val t = Transforms.truncate[CharSequence](width).bind(icebergInputType)
-      s: Any => {
-        val charSequence = t(StandardCharsets.UTF_8.decode(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes)))
-        val bb = StandardCharsets.UTF_8.encode(CharBuffer.wrap(charSequence));
-        UTF8String.fromBytes(ByteBuffers.toByteArray(bb))
-      }
-    case _: BinaryType =>
-      val t = Transforms.truncate[ByteBuffer](width).bind(icebergInputType)
-      s: Any => ByteBuffers.toByteArray(t(ByteBuffer.wrap(s.asInstanceOf[Array[Byte]])))
-    case _ =>
-      val t = Transforms.truncate[Any](width).bind(icebergInputType)
-      a: Any => t(a)
-  }
-
-  override protected def nullSafeEval(value: Any): Any = {
-    truncateFunc(value)
-  }
-
-  override def dataType: DataType = child.dataType
-
-  override protected def withNewChildInternal(newChild: Expression): Expression = {
-    copy(child = newChild)
-  }
-}
diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala
deleted file mode 100644
index 94b6f651a0df..000000000000
--- a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ - -package org.apache.spark.sql.catalyst.utils - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.expressions.IcebergBucketTransform -import org.apache.spark.sql.catalyst.expressions.IcebergDayTransform -import org.apache.spark.sql.catalyst.expressions.IcebergHourTransform -import org.apache.spark.sql.catalyst.expressions.IcebergMonthTransform -import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform -import org.apache.spark.sql.catalyst.expressions.IcebergYearTransform -import org.apache.spark.sql.catalyst.expressions.NamedExpression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression -import org.apache.spark.sql.catalyst.plans.logical.Sort -import org.apache.spark.sql.connector.distributions.ClusteredDistribution -import org.apache.spark.sql.connector.distributions.Distribution -import org.apache.spark.sql.connector.distributions.OrderedDistribution -import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution -import org.apache.spark.sql.connector.expressions.ApplyTransform -import org.apache.spark.sql.connector.expressions.BucketTransform -import org.apache.spark.sql.connector.expressions.DaysTransform -import org.apache.spark.sql.connector.expressions.Expression -import org.apache.spark.sql.connector.expressions.FieldReference -import org.apache.spark.sql.connector.expressions.HoursTransform -import org.apache.spark.sql.connector.expressions.IdentityTransform -import org.apache.spark.sql.connector.expressions.Literal -import org.apache.spark.sql.connector.expressions.MonthsTransform -import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.NullOrdering -import org.apache.spark.sql.connector.expressions.SortDirection -import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.connector.expressions.YearsTransform -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.types.IntegerType -import scala.collection.compat.immutable.ArraySeq - -object DistributionAndOrderingUtils { - - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - - def prepareQuery( - requiredDistribution: Distribution, - requiredOrdering: Array[SortOrder], - query: LogicalPlan, - conf: SQLConf): LogicalPlan = { - - val resolver = conf.resolver - - val distribution = requiredDistribution match { - case d: OrderedDistribution => - d.ordering.map(e => toCatalyst(e, query, resolver)) - case d: ClusteredDistribution => - d.clustering.map(e => toCatalyst(e, query, resolver)) - case _: UnspecifiedDistribution => - Array.empty[catalyst.expressions.Expression] - } - - val queryWithDistribution = if (distribution.nonEmpty) { - // the conversion to catalyst expressions above produces SortOrder expressions - // for OrderedDistribution and generic expressions for ClusteredDistribution - // this allows RepartitionByExpression to pick either range or hash partitioning - RepartitionByExpression(distribution.toSeq, query, None) - } else { - query - } - - val ordering = requiredOrdering - .map(e => toCatalyst(e, query, resolver).asInstanceOf[catalyst.expressions.SortOrder]) - - val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { - 
Sort(ArraySeq.unsafeWrapArray(ordering), global = false, queryWithDistribution) - } else { - queryWithDistribution - } - - queryWithDistributionAndOrdering - } - - private def toCatalyst( - expr: Expression, - query: LogicalPlan, - resolver: Resolver): catalyst.expressions.Expression = { - - // we cannot perform the resolution in the analyzer since we need to optimize expressions - // in nodes like OverwriteByExpression before constructing a logical write - def resolve(parts: Seq[String]): NamedExpression = { - query.resolve(parts, resolver) match { - case Some(attr) => - attr - case None => - val ref = parts.quoted - throw new AnalysisException(s"Cannot resolve '$ref' using ${query.output}") - } - } - - expr match { - case s: SortOrder => - val catalystChild = toCatalyst(s.expression(), query, resolver) - catalyst.expressions.SortOrder(catalystChild, toCatalyst(s.direction), toCatalyst(s.nullOrdering), Seq.empty) - case it: IdentityTransform => - resolve(ArraySeq.unsafeWrapArray(it.ref.fieldNames)) - case BucketTransform(numBuckets, ref) => - IcebergBucketTransform(numBuckets, resolve(ArraySeq.unsafeWrapArray(ref.fieldNames))) - case TruncateTransform(ref, width) => - IcebergTruncateTransform(resolve(ArraySeq.unsafeWrapArray(ref.fieldNames)), width) - case yt: YearsTransform => - IcebergYearTransform(resolve(ArraySeq.unsafeWrapArray(yt.ref.fieldNames))) - case mt: MonthsTransform => - IcebergMonthTransform(resolve(ArraySeq.unsafeWrapArray(mt.ref.fieldNames))) - case dt: DaysTransform => - IcebergDayTransform(resolve(ArraySeq.unsafeWrapArray(dt.ref.fieldNames))) - case ht: HoursTransform => - IcebergHourTransform(resolve(ArraySeq.unsafeWrapArray(ht.ref.fieldNames))) - case ref: FieldReference => - resolve(ArraySeq.unsafeWrapArray(ref.fieldNames)) - case _ => - throw new RuntimeException(s"$expr is not currently supported") - - } - } - - private def toCatalyst(direction: SortDirection): catalyst.expressions.SortDirection = { - direction match { - case SortDirection.ASCENDING => catalyst.expressions.Ascending - case SortDirection.DESCENDING => catalyst.expressions.Descending - } - } - - private def toCatalyst(nullOrdering: NullOrdering): catalyst.expressions.NullOrdering = { - nullOrdering match { - case NullOrdering.NULLS_FIRST => catalyst.expressions.NullsFirst - case NullOrdering.NULLS_LAST => catalyst.expressions.NullsLast - } - } - - private object BucketTransform { - def unapply(transform: Transform): Option[(Int, FieldReference)] = transform match { - case bt: BucketTransform => bt.columns match { - case Seq(nf: NamedReference) => - Some(bt.numBuckets.value(), FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames()))) - case _ => - None - } - case _ => None - } - } - - private object Lit { - def unapply[T](literal: Literal[T]): Some[(T, DataType)] = { - Some((literal.value, literal.dataType)) - } - } - - private object TruncateTransform { - def unapply(transform: Transform): Option[(FieldReference, Int)] = transform match { - case at @ ApplyTransform(name, _) if name.equalsIgnoreCase("truncate") => at.args match { - case Seq(nf: NamedReference, Lit(value: Int, IntegerType)) => - Some(FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames())), value) - case Seq(Lit(value: Int, IntegerType), nf: NamedReference) => - Some(FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames())), value) - case _ => - None - } - case _ => None - } - } -} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java 
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java
index 536dd5febbaa..d91ac3606d97 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java
@@ -463,9 +463,7 @@ public void testManyTopLevelPartitions() throws InterruptedException {
         "Should not delete any files", Iterables.isEmpty(result.orphanFileLocations()));
 
     Dataset<Row> resultDF = spark.read().format("iceberg").load(tableLocation);
-    List<ThreeColumnRecord> actualRecords =
-        resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList();
-    Assert.assertEquals("Rows must match", records, actualRecords);
+    Assert.assertEquals("Row count must match", records.size(), resultDF.count());
   }
 
   @Test
@@ -492,9 +490,7 @@ public void testManyLeafPartitions() throws InterruptedException {
         "Should not delete any files", Iterables.isEmpty(result.orphanFileLocations()));
 
     Dataset<Row> resultDF = spark.read().format("iceberg").load(tableLocation);
-    List<ThreeColumnRecord> actualRecords =
-        resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList();
-    Assert.assertEquals("Rows must match", records, actualRecords);
+    Assert.assertEquals("Row count must match", records.size(), resultDF.count());
   }
 
   @Test
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java
index 761284bb56ea..3ecd7ce37138 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java
@@ -93,6 +93,7 @@
 import org.apache.iceberg.spark.ScanTaskSetManager;
 import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.SparkTestBase;
+import org.apache.iceberg.spark.SparkWriteOptions;
 import org.apache.iceberg.spark.actions.RewriteDataFilesSparkAction.RewriteExecutionContext;
 import org.apache.iceberg.spark.source.ThreeColumnRecord;
 import org.apache.iceberg.types.Comparators;
@@ -102,8 +103,10 @@
 import org.apache.iceberg.util.Pair;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.internal.SQLConf;
 import org.junit.Assert;
 import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -127,6 +130,12 @@ public class TestRewriteDataFilesAction extends SparkTestBase {
   private final ScanTaskSetManager manager = ScanTaskSetManager.get();
   private String tableLocation = null;
 
+  @BeforeClass
+  public static void setupSpark() {
+    // disable AQE as tests assume that writes generate a particular number of files
+    spark.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false");
+  }
+
   @Before
   public void setupTableLocation() throws Exception {
     File tableDir = temp.newFolder();
@@ -1630,6 +1639,7 @@ private void writeDF(Dataset<Row> df) {
         .write()
         .format("iceberg")
         .mode("append")
+        .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false")
         .save(tableLocation);
   }
 
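Illustration only, not part of the patch: the two test knobs above can be reproduced in a spark-shell style Scala sketch, assuming an existing spark session, a DataFrame df, and a table location tableLocation. AQE is turned off so writes yield a deterministic number of files, and the helper write opts out of the table's declared distribution and ordering:

import org.apache.iceberg.spark.SparkWriteOptions
import org.apache.spark.sql.internal.SQLConf

// keep AQE from coalescing shuffle partitions, so file counts stay predictable
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")

// write test data as-is instead of redistributing/sorting it per the table spec
df.write
  .format("iceberg")
  .mode("append")
  .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false")
  .save(tableLocation)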
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java
index 521d90299d2b..ac481ca473bb 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java
@@ -215,7 +215,7 @@ public void testHashDistribution() throws NoSuchTableException {
   }
 
   @Test
-  public void testNoSortBucketTransformsWithoutExtensions() throws NoSuchTableException {
+  public void testSortBucketTransformsWithoutExtensions() throws NoSuchTableException {
     sql(
         "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) "
             + "USING iceberg "
@@ -231,20 +231,7 @@ public void testNoSortBucketTransformsWithoutExtensions() throws NoSuchTableExce
     Dataset<Row> ds = spark.createDataFrame(data, ThreeColumnRecord.class);
     Dataset<Row> inputDF = ds.coalesce(1).sortWithinPartitions("c1");
 
-    // should fail by default as extensions are disabled
-    AssertHelpers.assertThrowsCause(
-        "Should reject writes without ordering",
-        IllegalStateException.class,
-        "Incoming records violate the writer assumption",
-        () -> {
-          try {
-            inputDF.writeTo(tableName).append();
-          } catch (NoSuchTableException e) {
-            throw new RuntimeException(e);
-          }
-        });
-
-    inputDF.writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append();
+    inputDF.writeTo(tableName).append();
 
     List<Object[]> expected =
         ImmutableList.of(