diff --git a/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 index d8128d39052e..d1ab06f852c8 100644 --- a/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 +++ b/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 @@ -73,13 +73,33 @@ statement | ALTER TABLE multipartIdentifier WRITE writeSpec #setWriteDistributionAndOrdering | ALTER TABLE multipartIdentifier SET IDENTIFIER_KW FIELDS fieldList #setIdentifierFields | ALTER TABLE multipartIdentifier DROP IDENTIFIER_KW FIELDS fieldList #dropIdentifierFields - | ALTER TABLE multipartIdentifier CREATE BRANCH identifier (AS OF VERSION snapshotId)? (RETAIN snapshotRefRetain snapshotRefRetainTimeUnit)? (snapshotRetentionClause)? #createBranch + | ALTER TABLE multipartIdentifier createReplaceBranchClause #createOrReplaceBranch ; -snapshotRetentionClause - : WITH SNAPSHOT RETENTION numSnapshots SNAPSHOTS - | WITH SNAPSHOT RETENTION snapshotRetain snapshotRetainTimeUnit - | WITH SNAPSHOT RETENTION numSnapshots SNAPSHOTS snapshotRetain snapshotRetainTimeUnit +createReplaceBranchClause + : (CREATE OR)? REPLACE BRANCH identifier branchOptions + | CREATE BRANCH (IF NOT EXISTS)? identifier branchOptions + ; + +branchOptions + : (AS OF VERSION snapshotId)? (refRetain)? 
(snapshotRetention)?; + +snapshotRetention + : WITH SNAPSHOT RETENTION minSnapshotsToKeep + | WITH SNAPSHOT RETENTION maxSnapshotAge + | WITH SNAPSHOT RETENTION minSnapshotsToKeep maxSnapshotAge + ; + +refRetain + : RETAIN number timeUnit + ; + +maxSnapshotAge + : number timeUnit + ; + +minSnapshotsToKeep + : number SNAPSHOTS ; writeSpec @@ -175,7 +195,7 @@ fieldList ; nonReserved - : ADD | ALTER | AS | ASC | BRANCH | BY | CALL | CREATE | DAYS | DESC | DROP | FIELD | FIRST | HOURS | LAST | NULLS | OF | ORDERED | PARTITION | TABLE | WRITE + : ADD | ALTER | AS | ASC | BRANCH | BY | CALL | CREATE | DAYS | DESC | DROP | EXISTS | FIELD | FIRST | HOURS | IF | LAST | NOT | NULLS | OF | OR | ORDERED | PARTITION | TABLE | WRITE | DISTRIBUTED | LOCALLY | MINUTES | MONTHS | UNORDERED | REPLACE | RETAIN | VERSION | WITH | IDENTIFIER_KW | FIELDS | SET | SNAPSHOT | SNAPSHOTS | TRUE | FALSE | MAP @@ -189,22 +209,6 @@ numSnapshots : number ; -snapshotRetain - : number - ; - -snapshotRefRetain - : number - ; - -snapshotRefRetainTimeUnit - : timeUnit - ; - -snapshotRetainTimeUnit - : timeUnit - ; - timeUnit : DAYS | HOURS @@ -222,17 +226,21 @@ DAYS: 'DAYS'; DESC: 'DESC'; DISTRIBUTED: 'DISTRIBUTED'; DROP: 'DROP'; +EXISTS: 'EXISTS'; FIELD: 'FIELD'; FIELDS: 'FIELDS'; FIRST: 'FIRST'; HOURS: 'HOURS'; +IF : 'IF'; LAST: 'LAST'; LOCALLY: 'LOCALLY'; MINUTES: 'MINUTES'; MONTHS: 'MONTHS'; CREATE: 'CREATE'; +NOT: 'NOT'; NULLS: 'NULLS'; OF: 'OF'; +OR: 'OR'; ORDERED: 'ORDERED'; PARTITION: 'PARTITION'; REPLACE: 'REPLACE'; diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala index 4c059f7c343b..76af7d1ec608 100644 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala +++ 
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala @@ -206,7 +206,8 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI normalized.contains("write unordered") || normalized.contains("set identifier fields") || normalized.contains("drop identifier fields") || - normalized.contains("create branch"))) + normalized.contains("create branch") || + normalized.contains("replace branch"))) } diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala index 950e161f9f99..d6564d6ab927 100644 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala @@ -37,9 +37,10 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.parser.extensions.IcebergParserUtils.withOrigin import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._ import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.BranchOptions import org.apache.spark.sql.catalyst.plans.logical.CallArgument import org.apache.spark.sql.catalyst.plans.logical.CallStatement -import org.apache.spark.sql.catalyst.plans.logical.CreateBranch +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -91,25 +92,40 @@ class IcebergSqlExtensionsAstBuilder(delegate: 
ParserInterface) extends IcebergS typedVisit[Transform](ctx.transform)) } - /** - * Create an ADD BRANCH logical command. - */ - override def visitCreateBranch(ctx: CreateBranchContext): CreateBranch = withOrigin(ctx) { - val snapshotRetention = Option(ctx.snapshotRetentionClause()) - - CreateBranch( + override def visitCreateOrReplaceBranch(ctx: CreateOrReplaceBranchContext): CreateOrReplaceBranch = withOrigin(ctx) { + val createOrReplaceBranchClause = ctx.createReplaceBranchClause() + + val branchName = createOrReplaceBranchClause.identifier() + val branchOptionsContext = Option(createOrReplaceBranchClause.branchOptions()) + val snapshotId = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotId())) + .map(_.getText.toLong) + val snapshotRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotRetention())) + val minSnapshotsToKeep = snapshotRetention.flatMap(retention => Option(retention.minSnapshotsToKeep())) + .map(minSnapshots => minSnapshots.number().getText.toLong) + val maxSnapshotAgeMs = snapshotRetention + .flatMap(retention => Option(retention.maxSnapshotAge())) + .map(retention => TimeUnit.valueOf(retention.timeUnit().getText.toUpperCase(Locale.ENGLISH)) + .toMillis(retention.number().getText.toLong)) + val branchRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.refRetain())) + val branchRefAgeMs = branchRetention.map(retain => + TimeUnit.valueOf(retain.timeUnit().getText.toUpperCase(Locale.ENGLISH)).toMillis(retain.number().getText.toLong)) + val replace = ctx.createReplaceBranchClause().REPLACE() != null + val ifNotExists = createOrReplaceBranchClause.EXISTS() != null + + val branchOptions = BranchOptions( + snapshotId, + minSnapshotsToKeep, + maxSnapshotAgeMs, + branchRefAgeMs + ) + + CreateOrReplaceBranch( typedVisit[Seq[String]](ctx.multipartIdentifier), - ctx.identifier().getText, - Option(ctx.snapshotId()).map(_.getText.toLong), - snapshotRetention.flatMap(s => 
Option(s.numSnapshots())).map(_.getText.toLong), - snapshotRetention.flatMap(s => Option(s.snapshotRetain())).map(retain => { - TimeUnit.valueOf(ctx.snapshotRetentionClause().snapshotRetainTimeUnit().getText.toUpperCase(Locale.ENGLISH)) - .toMillis(retain.getText.toLong) - }), - Option(ctx.snapshotRefRetain()).map(retain => { - TimeUnit.valueOf(ctx.snapshotRefRetainTimeUnit().getText.toUpperCase(Locale.ENGLISH)) - .toMillis(retain.getText.toLong) - })) + branchName.getText, + branchOptions, + replace, + ifNotExists) + } /** diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala new file mode 100644 index 000000000000..4d7e0a086bda --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +case class BranchOptions (snapshotId: Option[Long], numSnapshots: Option[Long], + snapshotRetain: Option[Long], snapshotRefRetain: Option[Long]) diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala similarity index 79% rename from spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala rename to spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala index 91e2bc6f1951..24d6bd3d9123 100644 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala @@ -21,14 +21,15 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.Attribute -case class CreateBranch(table: Seq[String], branch: String, snapshotId: Option[Long], numSnapshots: Option[Long], - snapshotRetain: Option[Long], snapshotRefRetain: Option[Long]) extends LeafCommand { +case class CreateOrReplaceBranch(table: Seq[String], branch: String, + branchOptions: BranchOptions, replace: Boolean, ifNotExists: Boolean) + extends LeafCommand { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override lazy val output: Seq[Attribute] = Nil override def simpleString(maxFields: Int): String = { - s"Create branch: ${branch} for table: ${table.quoted} " + s"CreateOrReplaceBranch branch: ${branch} for table: ${table.quoted}" } } diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateBranchExec.scala 
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateBranchExec.scala deleted file mode 100644 index acaab93b0bd0..000000000000 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateBranchExec.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.iceberg.spark.source.SparkTable -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.CreateBranch -import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.catalog.TableCatalog - -case class CreateBranchExec( - catalog: TableCatalog, - ident: Identifier, - createBranch: CreateBranch) extends LeafV2CommandExec { - - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - - override lazy val output: Seq[Attribute] = Nil - - override protected def run(): Seq[InternalRow] = { - catalog.loadTable(ident) match { - case iceberg: SparkTable => - - val snapshotId = createBranch.snapshotId.getOrElse(iceberg.table.currentSnapshot().snapshotId()) - val manageSnapshot = iceberg.table.manageSnapshots() - .createBranch(createBranch.branch, snapshotId) - - if (createBranch.numSnapshots.nonEmpty) { - manageSnapshot.setMinSnapshotsToKeep(createBranch.branch, createBranch.numSnapshots.get.toInt) - } - - if (createBranch.snapshotRetain.nonEmpty) { - manageSnapshot.setMaxSnapshotAgeMs(createBranch.branch, createBranch.snapshotRetain.get) - } - - if (createBranch.snapshotRefRetain.nonEmpty) { - manageSnapshot.setMaxRefAgeMs(createBranch.branch, createBranch.snapshotRefRetain.get) - } - - manageSnapshot.commit() - - case table => - throw new UnsupportedOperationException(s"Cannot add branch to non-Iceberg table: $table") - } - - Nil - } - - override def simpleString(maxFields: Int): String = { - s"Create branch: ${createBranch.branch} operation for table: ${ident.quoted}" - } -} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala new file mode 
100644 index 000000000000..08230afb5a3f --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.BranchOptions +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class CreateOrReplaceBranchExec( + catalog: TableCatalog, + ident: Identifier, + branch: String, + branchOptions: BranchOptions, + replace: Boolean, + ifNotExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val snapshotId = branchOptions.snapshotId.getOrElse(iceberg.table.currentSnapshot().snapshotId()) + val manageSnapshots = iceberg.table().manageSnapshots() + if (!replace) { + val ref = iceberg.table().refs().get(branch); + if (ref != null && ifNotExists) { + return Nil + } + + manageSnapshots.createBranch(branch, snapshotId) + } else { + manageSnapshots.replaceBranch(branch, snapshotId) + } + + if (branchOptions.numSnapshots.nonEmpty) { + manageSnapshots.setMinSnapshotsToKeep(branch, branchOptions.numSnapshots.get.toInt) + } + + if (branchOptions.snapshotRetain.nonEmpty) { + manageSnapshots.setMaxSnapshotAgeMs(branch, branchOptions.snapshotRetain.get) + } + + if (branchOptions.snapshotRefRetain.nonEmpty) { + manageSnapshots.setMaxRefAgeMs(branch, branchOptions.snapshotRefRetain.get) + } + + manageSnapshots.commit() + + case table => + throw new UnsupportedOperationException(s"Cannot create or replace branch on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplace branch: ${branch} for table: ${ident.quoted}" + } +} diff --git 
a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala index 08c1c1dae61d..7e343534dede 100644 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField import org.apache.spark.sql.catalyst.plans.logical.Call -import org.apache.spark.sql.catalyst.plans.logical.CreateBranch +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField @@ -62,8 +62,9 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy wi case AddPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform, name) => AddPartitionFieldExec(catalog, ident, transform, name) :: Nil - case CreateBranch(IcebergCatalogAndIdentifier(catalog, ident), _, _, _, _, _) => - CreateBranchExec(catalog, ident, plan.asInstanceOf[CreateBranch]) :: Nil + case CreateOrReplaceBranch( + IcebergCatalogAndIdentifier(catalog, ident), branch, branchOptions, replace, ifNotExists) => + CreateOrReplaceBranchExec(catalog, ident, branch, branchOptions, replace, ifNotExists) :: Nil case DropPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform) => DropPartitionFieldExec(catalog, ident, transform) :: Nil diff --git 
a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java index 0379bcf7a91d..42d34779ee63 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java @@ -72,7 +72,7 @@ public void testCreateBranch() throws NoSuchTableException { tableName, branchName, snapshotId, maxRefAge, minSnapshotsToKeep, maxSnapshotAge); table.refresh(); SnapshotRef ref = table.refs().get(branchName); - Assert.assertNotNull(ref); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); Assert.assertEquals(TimeUnit.DAYS.toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); @@ -91,7 +91,7 @@ public void testCreateBranchUseDefaultConfig() throws NoSuchTableException { sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branchName); table.refresh(); SnapshotRef ref = table.refs().get(branchName); - Assert.assertNotNull(ref); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); Assert.assertNull(ref.minSnapshotsToKeep()); Assert.assertNull(ref.maxSnapshotAgeMs()); Assert.assertNull(ref.maxRefAgeMs()); @@ -107,7 +107,7 @@ public void testCreateBranchUseCustomMinSnapshotsToKeep() throws NoSuchTableExce tableName, branchName, minSnapshotsToKeep); table.refresh(); SnapshotRef ref = table.refs().get(branchName); - Assert.assertNotNull(ref); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); Assert.assertNull(ref.maxSnapshotAgeMs()); Assert.assertNull(ref.maxRefAgeMs()); @@ -129,6 +129,24 @@ public void 
testCreateBranchUseCustomMaxSnapshotAge() throws NoSuchTableExceptio Assert.assertNull(ref.maxRefAgeMs()); } + @Test + public void testCreateBranchIfNotExists() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = createDefaultTableAndInsert2Row(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d DAYS", + tableName, branchName, maxSnapshotAge); + sql("ALTER TABLE %s CREATE BRANCH IF NOT EXISTS %s", tableName, branchName); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + Assert.assertNull(ref.minSnapshotsToKeep()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertNull(ref.maxRefAgeMs()); + } + @Test public void testCreateBranchUseCustomMinSnapshotsToKeepAndMaxSnapshotAge() throws NoSuchTableException { @@ -141,7 +159,7 @@ public void testCreateBranchUseCustomMinSnapshotsToKeepAndMaxSnapshotAge() tableName, branchName, minSnapshotsToKeep, maxSnapshotAge); table.refresh(); SnapshotRef ref = table.refs().get(branchName); - Assert.assertNotNull(ref); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); Assert.assertNull(ref.maxRefAgeMs()); @@ -162,7 +180,7 @@ public void testCreateBranchUseCustomMaxRefAge() throws NoSuchTableException { sql("ALTER TABLE %s CREATE BRANCH %s RETAIN %d DAYS", tableName, branchName, maxRefAge); table.refresh(); SnapshotRef ref = table.refs().get(branchName); - Assert.assertNotNull(ref); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); Assert.assertNull(ref.minSnapshotsToKeep()); Assert.assertNull(ref.maxSnapshotAgeMs()); Assert.assertEquals(TimeUnit.DAYS.toMillis(maxRefAge), 
ref.maxRefAgeMs().longValue()); diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java new file mode 100644 index 000000000000..f97a95ff82a7 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestReplaceBranch extends SparkExtensionsTestBase { + + private static final String[] TIME_UNITS = {"DAYS", "HOURS", "MINUTES"}; + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + public TestReplaceBranch(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testReplaceBranchFailsForTag() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + String tagName = "tag1"; + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag(tagName, first).commit(); + df.writeTo(tableName).append(); + long second = 
table.currentSnapshot().snapshotId(); + + AssertHelpers.assertThrows( + "Cannot perform replace branch on tags", + IllegalArgumentException.class, + "Ref tag1 is a tag not a branch", + () -> sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, tagName, second)); + } + + @Test + public void testReplaceBranch() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + long expectedMaxRefAgeMs = 1000; + int expectedMinSnapshotsToKeep = 2; + long expectedMaxSnapshotAgeMs = 1000; + table + .manageSnapshots() + .createBranch(branchName, first) + .setMaxRefAgeMs(branchName, expectedMaxRefAgeMs) + .setMinSnapshotsToKeep(branchName, expectedMinSnapshotsToKeep) + .setMaxSnapshotAgeMs(branchName, expectedMaxSnapshotAgeMs) + .commit(); + + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, branchName, second); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(ref.snapshotId(), second); + Assert.assertEquals(expectedMinSnapshotsToKeep, ref.minSnapshotsToKeep().intValue()); + Assert.assertEquals(expectedMaxSnapshotAgeMs, ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals(expectedMaxRefAgeMs, ref.maxRefAgeMs().longValue()); + } + + @Test + public void testReplaceBranchDoesNotExist() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = 
spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + Table table = validationCatalog.loadTable(tableIdent); + + AssertHelpers.assertThrows( + "Cannot perform replace branch on branch which does not exist", + IllegalArgumentException.class, + "Branch does not exist", + () -> + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", + tableName, "someBranch", table.currentSnapshot().snapshotId())); + } + + @Test + public void testReplaceBranchWithRetain() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, first).commit(); + SnapshotRef b1 = table.refs().get(branchName); + Integer minSnapshotsToKeep = b1.minSnapshotsToKeep(); + Long maxSnapshotAgeMs = b1.maxSnapshotAgeMs(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s", + tableName, branchName, second, maxRefAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(ref.snapshotId(), second); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals(maxSnapshotAgeMs, ref.maxSnapshotAgeMs()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); + } + } + + @Test + public void testReplaceBranchWithSnapshotRetention() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + 
List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + String branchName = "b1"; + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2; + Long maxRefAgeMs = table.refs().get(branchName).maxRefAgeMs(); + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s", + tableName, branchName, second, minSnapshotsToKeep, maxSnapshotAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(ref.snapshotId(), second); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals(maxRefAgeMs, ref.maxRefAgeMs()); + } + } + + @Test + public void testReplaceBranchWithRetainAndSnapshotRetention() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2; + long maxRefAge = 10; + for (String 
timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s", + tableName, + branchName, + second, + maxRefAge, + timeUnit, + minSnapshotsToKeep, + maxSnapshotAge, + timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(ref.snapshotId(), second); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); + } + } + + @Test + public void testCreateOrReplace() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, second).commit(); + + sql( + "ALTER TABLE %s CREATE OR REPLACE BRANCH %s AS OF VERSION %d", + tableName, branchName, first); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(ref.snapshotId(), first); + } +}