From bffd3ff2469ea9e4b0a671f0a48e198e2ddcfe04 Mon Sep 17 00:00:00 2001 From: Reetika Agrawal Date: Mon, 16 Mar 2026 12:44:56 +0530 Subject: [PATCH] Fix thread safety in Iceberg procedures --- .../procedure/FastForwardBranchProcedure.java | 13 ++++++++----- .../RollbackToSnapshotProcedure.java | 13 ++++++++----- .../SetCurrentSnapshotProcedure.java | 19 +++++++++++-------- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java index d1eba998e5c0f..1b0481d006935 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java @@ -17,6 +17,7 @@ import com.facebook.presto.iceberg.transaction.IcebergTransactionMetadata; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -68,10 +69,12 @@ public Procedure get() public void fastForwardToBranch(ConnectorSession clientSession, String schemaName, String tableName, String fromBranch, String targetBranch) { - SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName); - IcebergTransactionMetadata metadata = metadataFactory.create(); - Table icebergTable = getIcebergTable(metadata, clientSession, schemaTableName); - icebergTable.manageSnapshots().fastForwardBranch(fromBranch, targetBranch).commit(); - metadata.commit(); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName); + IcebergTransactionMetadata metadata = metadataFactory.create(); + Table icebergTable = getIcebergTable(metadata, clientSession, schemaTableName); + icebergTable.manageSnapshots().fastForwardBranch(fromBranch, targetBranch).commit(); + metadata.commit(); + } } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java index 2462e49b7c336..a5701951b6850 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java @@ -17,6 +17,7 @@ import com.facebook.presto.iceberg.transaction.IcebergTransactionMetadata; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -66,10 +67,12 @@ public Procedure get() public void rollbackToSnapshot(ConnectorSession clientSession, String schema, String table, Long snapshotId) { - SchemaTableName schemaTableName = new SchemaTableName(schema, table); - IcebergTransactionMetadata metadata = metadataFactory.create(); - getIcebergTable(metadata, clientSession, schemaTableName) - .manageSnapshots().rollbackTo(snapshotId).commit(); - metadata.commit(); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + SchemaTableName schemaTableName = new SchemaTableName(schema, table); + IcebergTransactionMetadata metadata = metadataFactory.create(); + getIcebergTable(metadata, clientSession, schemaTableName) + .manageSnapshots().rollbackTo(snapshotId).commit(); + metadata.commit(); + } } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java index 6413273f0599c..10b66b7c024d8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java @@ -17,6 +17,7 @@ import com.facebook.presto.iceberg.transaction.IcebergTransactionMetadata; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -71,14 +72,16 @@ public Procedure get() public void setCurrentSnapshot(ConnectorSession clientSession, String schema, String table, Long snapshotId, String reference) { - checkState((snapshotId != null && reference == null) || (snapshotId == null && reference != null), - "Either snapshot_id or reference must be provided, not both"); - SchemaTableName schemaTableName = new SchemaTableName(schema, table); - IcebergTransactionMetadata metadata = metadataFactory.create(); - Table icebergTable = getIcebergTable(metadata, clientSession, schemaTableName); - long targetSnapshotId = snapshotId != null ? snapshotId : getSnapshotIdFromReference(icebergTable, reference); - icebergTable.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); - metadata.commit(); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + checkState((snapshotId != null && reference == null) || (snapshotId == null && reference != null), + "Either snapshot_id or reference must be provided, not both"); + SchemaTableName schemaTableName = new SchemaTableName(schema, table); + IcebergTransactionMetadata metadata = metadataFactory.create(); + Table icebergTable = getIcebergTable(metadata, clientSession, schemaTableName); + long targetSnapshotId = snapshotId != null ? snapshotId : getSnapshotIdFromReference(icebergTable, reference); + icebergTable.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); + metadata.commit(); + } } private long getSnapshotIdFromReference(Table table, String refName)