diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java index c6a3242f0b7ff..b851cdc81146a 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java @@ -57,7 +57,6 @@ public void configure(Binder binder) { binder.bind(CassandraConnectorId.class).toInstance(new CassandraConnectorId(connectorId)); binder.bind(CassandraConnector.class).in(Scopes.SINGLETON); - binder.bind(CassandraMetadata.class).in(Scopes.SINGLETON); binder.bind(CassandraSplitManager.class).in(Scopes.SINGLETON); binder.bind(CassandraTokenSplitManager.class).in(Scopes.SINGLETON); binder.bind(CassandraRecordSetProvider.class).in(Scopes.SINGLETON); diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java index 016f2c8022465..9562a9afb6971 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java @@ -14,8 +14,10 @@ package com.facebook.presto.cassandra; import com.facebook.airlift.bootstrap.LifeCycleManager; +import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorCommitHandle; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; @@ -26,9 +28,13 @@ import jakarta.inject.Inject; import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import static com.facebook.presto.spi.connector.EmptyConnectorCommitHandle.INSTANCE; import static com.facebook.presto.spi.transaction.IsolationLevel.READ_UNCOMMITTED; import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class CassandraConnector @@ -36,35 +42,66 @@ public class CassandraConnector { private static final Logger log = Logger.get(CassandraConnector.class); + private final CassandraConnectorId connectorId; private final LifeCycleManager lifeCycleManager; - private final CassandraMetadata metadata; + private final CassandraPartitionManager partitionManager; + private final CassandraClientConfig config; + private final CassandraSession cassandraSession; private final CassandraSplitManager splitManager; private final ConnectorRecordSetProvider recordSetProvider; private final ConnectorPageSinkProvider pageSinkProvider; private final List> sessionProperties; + private final JsonCodec> extraColumnMetadataCodec; + private final ConcurrentMap transactions = new ConcurrentHashMap<>(); @Inject public CassandraConnector( + CassandraConnectorId connectorId, LifeCycleManager lifeCycleManager, - CassandraMetadata metadata, CassandraSplitManager splitManager, CassandraRecordSetProvider recordSetProvider, CassandraPageSinkProvider pageSinkProvider, - CassandraSessionProperties sessionProperties) + CassandraSessionProperties sessionProperties, + CassandraSession cassandraSession, + CassandraPartitionManager partitionManager, + JsonCodec> extraColumnMetadataCodec, + CassandraClientConfig config) { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); - this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.recordSetProvider = requireNonNull(recordSetProvider, "recordSetProvider is null"); this.pageSinkProvider = requireNonNull(pageSinkProvider, "pageSinkProvider is null"); this.sessionProperties = requireNonNull(sessionProperties.getSessionProperties(), "sessionProperties is null"); + this.partitionManager = requireNonNull(partitionManager, "partitionManager is null"); + this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); + this.config = requireNonNull(config, "config is null"); + this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); } @Override public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) { checkConnectorSupports(READ_UNCOMMITTED, isolationLevel); - return CassandraTransactionHandle.INSTANCE; + CassandraTransactionHandle transaction = new CassandraTransactionHandle(); + transactions.put(transaction, + new CassandraMetadata(connectorId, cassandraSession, partitionManager, extraColumnMetadataCodec, config)); + return transaction; + } + + @Override + public ConnectorCommitHandle commit(ConnectorTransactionHandle transaction) + { + checkArgument(transactions.remove(transaction) != null, "no such transaction: %s", transaction); + return INSTANCE; + } + + @Override + public void rollback(ConnectorTransactionHandle transaction) + { + CassandraMetadata metadata = transactions.remove(transaction); + checkArgument(metadata != null, "no such transaction: %s", transaction); + metadata.rollback(); } @Override @@ -74,8 +111,10 @@ public boolean isSingleStatementWritesOnly() } @Override - public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transaction) { + CassandraMetadata metadata = transactions.get(transaction); + checkArgument(metadata != null, "no such transaction: %s", transaction); return metadata; } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java index da86b2293fdcc..b4ea65a7ecb48 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java @@ -41,13 +41,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; -import jakarta.inject.Inject; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import static com.facebook.presto.cassandra.CassandraType.toCassandraType; @@ -57,6 +57,7 @@ import static com.facebook.presto.spi.StandardErrorCode.PERMISSION_DENIED; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; @@ -72,8 +73,8 @@ public class CassandraMetadata private boolean caseSensitiveNameMatchingEnabled; private final JsonCodec> extraColumnMetadataCodec; + private final AtomicReference rollbackAction = new AtomicReference<>(); - @Inject public CassandraMetadata( CassandraConnectorId connectorId, CassandraSession cassandraSession, @@ -319,6 +320,9 @@ private CassandraOutputTableHandle createTable(ConnectorSession session, Connect // We need to create the Cassandra table before commit because the record needs to be written to the table. cassandraSession.execute(queryBuilder.toString()); + + // set a rollback to delete the created table in case of an abort / failure. + setRollback(schemaName, tableName); return new CassandraOutputTableHandle( connectorId, schemaName, @@ -330,6 +334,7 @@ private CassandraOutputTableHandle createTable(ConnectorSession session, Connect @Override public Optional finishCreateTable(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments, Collection computedStatistics) { + clearRollback(); return Optional.empty(); } @@ -365,4 +370,30 @@ public String normalizeIdentifier(ConnectorSession session, String identifier) { return caseSensitiveNameMatchingEnabled ? identifier : identifier.toLowerCase(ROOT); } + + public void rollback() + { + Runnable action = rollbackAction.getAndSet(null); + if (action == null) { + return; // nothing to roll back + } + + if (!allowDropTable) { + throw new PrestoException( + PERMISSION_DENIED, + "Table creation was aborted and requires rollback, but cleanup failed because DROP TABLE is disabled in this Cassandra catalog."); + } + + action.run(); + } + + private void setRollback(String schemaName, String tableName) + { + checkState(rollbackAction.compareAndSet(null, () -> cassandraSession.execute(String.format("DROP TABLE \"%s\".\"%s\"", schemaName, tableName))), "rollback action is already set"); + } + + private void clearRollback() + { + rollbackAction.set(null); + } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTransactionHandle.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTransactionHandle.java index 7a2eb23d4f162..4128e287135ef 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTransactionHandle.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTransactionHandle.java @@ -14,9 +14,61 @@ package com.facebook.presto.cassandra; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; -public enum CassandraTransactionHandle +import java.util.Objects; +import java.util.UUID; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class CassandraTransactionHandle implements ConnectorTransactionHandle { - INSTANCE + private final UUID uuid; + + public CassandraTransactionHandle() + { + this(UUID.randomUUID()); + } + + @JsonCreator + public CassandraTransactionHandle(@JsonProperty("uuid") UUID uuid) + { + this.uuid = requireNonNull(uuid, "uuid is null"); + } + + @JsonProperty + public UUID getUuid() + { + return uuid; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + CassandraTransactionHandle other = (CassandraTransactionHandle) obj; + return Objects.equals(uuid, other.uuid); + } + + @Override + public int hashCode() + { + return Objects.hash(uuid); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("uuid", uuid) + .toString(); + } } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java index 49a8a2bdc247d..a3535c0e46b4f 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorSplitSource; @@ -66,10 +67,12 @@ import static com.facebook.presto.common.type.Varchars.isVarcharType; import static com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; +import static com.facebook.presto.spi.transaction.IsolationLevel.READ_UNCOMMITTED; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Locale.ENGLISH; import static java.util.Locale.ROOT; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -98,10 +101,11 @@ public class TestCassandraConnector protected SchemaTableName table; protected SchemaTableName tableUnpartitioned; protected SchemaTableName invalidTable; + protected SchemaTableName rollbackTable; private CassandraServer server; - private ConnectorMetadata metadata; private ConnectorSplitManager splitManager; private ConnectorRecordSetProvider recordSetProvider; + private Connector connector; @BeforeClass public void setup() @@ -115,14 +119,12 @@ public void setup() String connectorId = "cassandra-test"; CassandraConnectorFactory connectorFactory = new CassandraConnectorFactory(connectorId); - Connector connector = connectorFactory.create(connectorId, ImmutableMap.of( - "cassandra.contact-points", server.getHost(), - "cassandra.native-protocol-port", Integer.toString(server.getPort())), + connector = connectorFactory.create(connectorId, ImmutableMap.of( + "cassandra.contact-points", server.getHost(), + "cassandra.native-protocol-port", Integer.toString(server.getPort()), + "cassandra.allow-drop-table", "true"), new TestingConnectorContext()); - metadata = connector.getMetadata(CassandraTransactionHandle.INSTANCE); - assertInstanceOf(metadata, CassandraMetadata.class); - splitManager = connector.getSplitManager(); assertInstanceOf(splitManager, CassandraSplitManager.class); @@ -133,6 +135,7 @@ public void setup() table = new SchemaTableName(database, TABLE_ALL_TYPES.toLowerCase(ROOT)); tableUnpartitioned = new SchemaTableName(database, "presto_test_unpartitioned"); invalidTable = new SchemaTableName(database, "totally_invalid_table_name"); + rollbackTable = new SchemaTableName(database, "rollback_table"); } @Test @@ -149,6 +152,8 @@ public void tearDown() @Test public void testGetDatabaseNames() { + ConnectorTransactionHandle transactionHandle = connector.beginTransaction(READ_UNCOMMITTED, true); + ConnectorMetadata metadata = connector.getMetadata(transactionHandle); List databases = metadata.listSchemaNames(SESSION); assertTrue(databases.contains(database.toLowerCase(ROOT))); } @@ -156,6 +161,8 @@ public void testGetDatabaseNames() @Test public void testGetTableNames() { + ConnectorTransactionHandle transactionHandle = connector.beginTransaction(READ_UNCOMMITTED, true); + ConnectorMetadata metadata = connector.getMetadata(transactionHandle); List tables = metadata.listTables(SESSION, database); assertTrue(tables.contains(table)); } @@ -164,12 +171,16 @@ public void testGetTableNames() @Test(enabled = false, expectedExceptions = SchemaNotFoundException.class) public void testGetTableNamesException() { + ConnectorTransactionHandle transactionHandle = connector.beginTransaction(READ_UNCOMMITTED, true); + ConnectorMetadata metadata = connector.getMetadata(transactionHandle); metadata.listTables(SESSION, INVALID_DATABASE); } @Test public void testListUnknownSchema() { + ConnectorTransactionHandle transactionHandle = connector.beginTransaction(READ_UNCOMMITTED, true); + ConnectorMetadata metadata = connector.getMetadata(transactionHandle); assertNull(metadata.getTableHandle(SESSION, new SchemaTableName("totally_invalid_database_name", "dual"))); assertEquals(metadata.listTables(SESSION, "totally_invalid_database_name"), ImmutableList.of()); assertEquals(metadata.listTableColumns(SESSION, new SchemaTablePrefix("totally_invalid_database_name", "dual")), ImmutableMap.of()); @@ -178,23 +189,23 @@ public void testListUnknownSchema() @Test public void testGetRecords() { - ConnectorTableHandle tableHandle = getTableHandle(table); + ConnectorTransactionHandle transactionHandle = connector.beginTransaction(READ_UNCOMMITTED, true); + ConnectorMetadata metadata = connector.getMetadata(transactionHandle); + ConnectorTableHandle tableHandle = getTableHandle(table, metadata); ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(SESSION, tableHandle); List columnHandles = ImmutableList.copyOf(metadata.getColumnHandles(SESSION, tableHandle).values()); Map columnIndex = indexColumns(columnHandles); - ConnectorTransactionHandle transaction = CassandraTransactionHandle.INSTANCE; - ConnectorTableLayoutResult layoutResult = metadata.getTableLayoutForConstraint(SESSION, tableHandle, Constraint.alwaysTrue(), Optional.empty()); ConnectorTableLayoutHandle layout = layoutResult.getTableLayout().getHandle(); - List splits = getAllSplits(splitManager.getSplits(transaction, SESSION, layout, new SplitSchedulingContext(UNGROUPED_SCHEDULING, false, WarningCollector.NOOP))); + List splits = getAllSplits(splitManager.getSplits(transactionHandle, SESSION, layout, new SplitSchedulingContext(UNGROUPED_SCHEDULING, false, WarningCollector.NOOP))); long rowNumber = 0; for (ConnectorSplit split : splits) { CassandraSplit cassandraSplit = (CassandraSplit) split; long completedBytes = 0; - try (RecordCursor cursor = recordSetProvider.getRecordSet(transaction, SESSION, cassandraSplit, columnHandles).cursor()) { + try (RecordCursor cursor = recordSetProvider.getRecordSet(transactionHandle, SESSION, cassandraSplit, columnHandles).cursor()) { while (cursor.advanceNextPosition()) { try { assertReadFields(cursor, tableMetadata.getColumns()); @@ -231,6 +242,39 @@ public void testGetRecords() assertEquals(rowNumber, 9); } + @Test + public void testRollbackTables() + { + ConnectorTableMetadata connectorTableMetadata = new ConnectorTableMetadata( + rollbackTable, + ImmutableList.of( + ColumnMetadata.builder() + .setName("test_col") + .setType(BIGINT) + .build())); + + // start a transaction + ConnectorTransactionHandle transactionHandle = connector.beginTransaction(READ_UNCOMMITTED, true); + ConnectorMetadata metadata = connector.getMetadata(transactionHandle); + ConnectorOutputTableHandle handle = null; + + try { + // Begin table creation (STAGING only) + handle = metadata.beginCreateTable(SESSION, connectorTableMetadata, Optional.empty()); + // simulate a failure + throw new RuntimeException("Force failure before finish"); + } + catch (RuntimeException e) { + if (handle != null) { + // table should exist + assertTrue(metadata.listTables(SESSION, database).contains(rollbackTable)); + // rollback table + connector.rollback(transactionHandle); + } + } + assertFalse(metadata.listTables(SESSION, database).contains(rollbackTable)); + } + private static void assertReadFields(RecordCursor cursor, List schema) { for (int columnIndex = 0; columnIndex < schema.size(); columnIndex++) { @@ -270,7 +314,7 @@ else if (isVarcharType(type) || VARBINARY.equals(type)) { } } - private ConnectorTableHandle getTableHandle(SchemaTableName tableName) + private ConnectorTableHandle getTableHandle(SchemaTableName tableName, ConnectorMetadata metadata) { ConnectorTableHandle handle = metadata.getTableHandle(SESSION, tableName); checkArgument(handle != null, "table not found: %s", tableName);