diff --git a/airbyte-config/config-persistence/src/main/java/io/airbyte/config/persistence/StatePersistence.java b/airbyte-config/config-persistence/src/main/java/io/airbyte/config/persistence/StatePersistence.java new file mode 100644 index 000000000000..a23d1f0c4e0f --- /dev/null +++ b/airbyte-config/config-persistence/src/main/java/io/airbyte/config/persistence/StatePersistence.java @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.config.persistence; + +import static io.airbyte.db.instance.configs.jooq.generated.Tables.STATE; + +import com.fasterxml.jackson.databind.JsonNode; +import io.airbyte.commons.enums.Enums; +import io.airbyte.commons.json.Jsons; +import io.airbyte.config.StateType; +import io.airbyte.config.StateWrapper; +import io.airbyte.db.Database; +import io.airbyte.db.ExceptionWrappingDatabase; +import io.airbyte.protocol.models.AirbyteGlobalState; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType; +import io.airbyte.protocol.models.AirbyteStreamState; +import io.airbyte.protocol.models.StreamDescriptor; +import java.io.IOException; +import java.time.OffsetDateTime; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import org.jooq.Condition; +import org.jooq.DSLContext; +import org.jooq.Field; +import org.jooq.JSONB; +import org.jooq.Record; +import org.jooq.RecordMapper; +import org.jooq.impl.DSL; + +/** + * State Persistence + * + * Handle persisting States to the Database. + * + * Supports migration from Legacy to Global or Stream. Other type migrations need to go through a + * reset. (an exception will be thrown) + */ +public class StatePersistence { + + private final ExceptionWrappingDatabase database; + + public StatePersistence(final Database database) { + this.database = new ExceptionWrappingDatabase(database); + } + + /** + * Get the current State of a Connection + * + * @param connectionId + * @return + * @throws IOException + */ + public Optional getCurrentState(final UUID connectionId) throws IOException { + final List records = this.database.query(ctx -> getStateRecords(ctx, connectionId)); + + if (records.isEmpty()) { + return Optional.empty(); + } + + return switch (getStateType(connectionId, records)) { + case GLOBAL -> Optional.of(buildGlobalState(records)); + case STREAM -> Optional.of(buildStreamState(records)); + default -> Optional.of(buildLegacyState(records)); + }; + } + + /** + * Create or update the states described in the StateWrapper. Null states will be deleted. + * + * The only state migrations supported are going from a Legacy state to either a Global or Stream + * state. Other state type migrations should go through an explicit reset. An exception will be + * thrown to prevent the system from getting into a bad state. + * + * @param connectionId + * @param state + * @throws IOException + */ + public void updateOrCreateState(final UUID connectionId, final StateWrapper state) throws IOException { + final Optional previousState = getCurrentState(connectionId); + final boolean isMigration = previousState.isPresent() && previousState.get().getStateType() == StateType.LEGACY && + state.getStateType() != StateType.LEGACY; + + // The only case where we allow a state migration is moving from LEGACY. + // We expect any other migration to go through an explicit reset. + if (!isMigration && previousState.isPresent() && previousState.get().getStateType() != state.getStateType()) { + throw new IllegalStateException("Unexpected type migration from '" + previousState.get().getStateType() + "' to '" + state.getStateType() + + "'. Migration of StateType need to go through an explicit reset."); + } + + this.database.transaction(ctx -> { + if (isMigration) { + clearLegacyState(ctx, connectionId); + } + switch (state.getStateType()) { + case GLOBAL -> saveGlobalState(ctx, connectionId, state.getGlobal().getGlobal()); + case STREAM -> saveStreamState(ctx, connectionId, state.getStateMessages()); + case LEGACY -> saveLegacyState(ctx, connectionId, state.getLegacyState()); + } + return null; + }); + } + + private static void clearLegacyState(final DSLContext ctx, final UUID connectionId) { + writeStateToDb(ctx, connectionId, null, null, StateType.LEGACY, null); + } + + private static void saveGlobalState(final DSLContext ctx, final UUID connectionId, final AirbyteGlobalState globalState) { + writeStateToDb(ctx, connectionId, null, null, StateType.GLOBAL, globalState.getSharedState()); + for (final AirbyteStreamState streamState : globalState.getStreamStates()) { + writeStateToDb(ctx, + connectionId, + streamState.getStreamDescriptor().getName(), + streamState.getStreamDescriptor().getNamespace(), + StateType.GLOBAL, + streamState.getStreamState()); + } + } + + private static void saveStreamState(final DSLContext ctx, final UUID connectionId, final List stateMessages) { + for (final AirbyteStateMessage stateMessage : stateMessages) { + final AirbyteStreamState streamState = stateMessage.getStream(); + writeStateToDb(ctx, + connectionId, + streamState.getStreamDescriptor().getName(), + streamState.getStreamDescriptor().getNamespace(), + StateType.STREAM, + streamState.getStreamState()); + } + } + + private static void saveLegacyState(final DSLContext ctx, final UUID connectionId, final JsonNode state) { + writeStateToDb(ctx, connectionId, null, null, StateType.LEGACY, state); + } + + /** + * Performs the actual SQL operation depending on the state + * + * If the state is null, it will delete the row, otherwise do an insert or update on conflict + */ + static void writeStateToDb(final DSLContext ctx, + final UUID connectionId, + final String streamName, + final String namespace, + final StateType stateType, + final JsonNode state) { + if (state != null) { + final boolean hasState = ctx.selectFrom(STATE) + .where( + STATE.CONNECTION_ID.eq(connectionId), + isNullOrEquals(STATE.STREAM_NAME, streamName), + isNullOrEquals(STATE.NAMESPACE, namespace)) + .fetch().isNotEmpty(); + + final JSONB jsonbState = JSONB.valueOf(Jsons.serialize(state)); + final OffsetDateTime now = OffsetDateTime.now(); + + if (!hasState) { + ctx.insertInto(STATE) + .columns( + STATE.ID, + STATE.CREATED_AT, + STATE.UPDATED_AT, + STATE.CONNECTION_ID, + STATE.STREAM_NAME, + STATE.NAMESPACE, + STATE.STATE_, + STATE.TYPE) + .values( + UUID.randomUUID(), + now, + now, + connectionId, + streamName, + namespace, + jsonbState, + Enums.convertTo(stateType, io.airbyte.db.instance.configs.jooq.generated.enums.StateType.class)) + .execute(); + + } else { + ctx.update(STATE) + .set(STATE.UPDATED_AT, now) + .set(STATE.STATE_, jsonbState) + .where( + STATE.CONNECTION_ID.eq(connectionId), + isNullOrEquals(STATE.STREAM_NAME, streamName), + isNullOrEquals(STATE.NAMESPACE, namespace)) + .execute(); + } + + } else { + // If the state is null, we remove the state instead of keeping a null row + ctx.deleteFrom(STATE) + .where( + STATE.CONNECTION_ID.eq(connectionId), + isNullOrEquals(STATE.STREAM_NAME, streamName), + isNullOrEquals(STATE.NAMESPACE, namespace)) + .execute(); + } + } + + /** + * Helper function to handle null or equal case for the optional strings + * + * We need to have an explicit check for null values because NULL != "str" is NULL, not a boolean. + * + * @param field the targeted field + * @param value the value to check + * @return The Condition that performs the desired check + */ + private static Condition isNullOrEquals(final Field field, final String value) { + return value != null ? field.eq(value) : field.isNull(); + } + + /** + * Get the StateType for a given list of StateRecords + * + * @param connectionId The connectionId of the records, used to add more debugging context if an + * error is detected + * @param records The list of StateRecords to process, must not be empty + * @return the StateType of the records + * @throws IllegalStateException If StateRecords have inconsistent types + */ + private static io.airbyte.db.instance.configs.jooq.generated.enums.StateType getStateType( + final UUID connectionId, + final List records) { + final Set types = + records.stream().map(r -> r.type).collect(Collectors.toSet()); + if (types.size() == 1) { + return types.stream().findFirst().get(); + } + + throw new IllegalStateException("Inconsistent StateTypes for connectionId " + connectionId + + " (" + String.join(", ", types.stream().map(stateType -> stateType.getLiteral()).toList()) + ")"); + } + + /** + * Get the state records from the DB + * + * @param ctx A valid DSL context to use for the query + * @param connectionId the ID of the connection + * @return The StateRecords for the connectionId + */ + private static List getStateRecords(final DSLContext ctx, final UUID connectionId) { + return ctx.select(DSL.asterisk()) + .from(STATE) + .where(STATE.CONNECTION_ID.eq(connectionId)) + .fetch(getStateRecordMapper()) + .stream().toList(); + } + + /** + * Build Global state + * + * The list of records can contain one global shared state that is the state without streamName and + * without namespace The other records should be translated into AirbyteStreamState + */ + private static StateWrapper buildGlobalState(final List records) { + // Split the global shared state from the other per stream records + final Map> partitions = records.stream() + .collect(Collectors.partitioningBy(r -> r.streamName == null && r.namespace == null)); + + final AirbyteGlobalState globalState = new AirbyteGlobalState() + .withSharedState(partitions.get(Boolean.TRUE).stream().map(r -> r.state).findFirst().orElse(null)) + .withStreamStates(partitions.get(Boolean.FALSE).stream().map(StatePersistence::buildAirbyteStreamState).toList()); + + final AirbyteStateMessage msg = new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(globalState); + return new StateWrapper().withStateType(StateType.GLOBAL).withGlobal(msg); + } + + /** + * Build StateWrapper for a PerStream state + */ + private static StateWrapper buildStreamState(final List records) { + final List messages = records.stream().map( + record -> new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(buildAirbyteStreamState(record))) + .toList(); + return new StateWrapper().withStateType(StateType.STREAM).withStateMessages(messages); + } + + /** + * Build a StateWrapper for Legacy state + */ + private static StateWrapper buildLegacyState(final List records) { + return new StateWrapper() + .withStateType(StateType.LEGACY) + .withLegacyState(records.get(0).state); + } + + /** + * Convert a StateRecord to an AirbyteStreamState + */ + private static AirbyteStreamState buildAirbyteStreamState(final StateRecord record) { + return new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName(record.streamName).withNamespace(record.namespace)) + .withStreamState(record.state); + } + + private static RecordMapper getStateRecordMapper() { + return record -> new StateRecord( + record.get(STATE.TYPE, io.airbyte.db.instance.configs.jooq.generated.enums.StateType.class), + record.get(STATE.STREAM_NAME, String.class), + record.get(STATE.NAMESPACE, String.class), + Jsons.deserialize(record.get(STATE.STATE_).data())); + } + + private record StateRecord( + io.airbyte.db.instance.configs.jooq.generated.enums.StateType type, + String streamName, + String namespace, + JsonNode state) {} + +} diff --git a/airbyte-config/config-persistence/src/test/java/io/airbyte/config/persistence/StatePersistenceTest.java b/airbyte-config/config-persistence/src/test/java/io/airbyte/config/persistence/StatePersistenceTest.java new file mode 100644 index 000000000000..0c4e70dcf522 --- /dev/null +++ b/airbyte-config/config-persistence/src/test/java/io/airbyte/config/persistence/StatePersistenceTest.java @@ -0,0 +1,555 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.config.persistence; + +import static org.mockito.Mockito.mock; + +import com.fasterxml.jackson.databind.JsonNode; +import io.airbyte.commons.enums.Enums; +import io.airbyte.commons.json.Jsons; +import io.airbyte.config.DestinationConnection; +import io.airbyte.config.SourceConnection; +import io.airbyte.config.StandardDestinationDefinition; +import io.airbyte.config.StandardSourceDefinition; +import io.airbyte.config.StandardSync; +import io.airbyte.config.StandardWorkspace; +import io.airbyte.config.StateType; +import io.airbyte.config.StateWrapper; +import io.airbyte.config.persistence.split_secrets.JsonSecretsProcessor; +import io.airbyte.db.factory.DSLContextFactory; +import io.airbyte.db.factory.FlywayFactory; +import io.airbyte.db.init.DatabaseInitializationException; +import io.airbyte.db.instance.configs.ConfigsDatabaseMigrator; +import io.airbyte.db.instance.configs.ConfigsDatabaseTestProvider; +import io.airbyte.protocol.models.AirbyteGlobalState; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType; +import io.airbyte.protocol.models.AirbyteStreamState; +import io.airbyte.protocol.models.StreamDescriptor; +import io.airbyte.test.utils.DatabaseConnectionHelper; +import io.airbyte.validation.json.JsonValidationException; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Optional; +import java.util.UUID; +import org.jooq.JSONB; +import org.jooq.SQLDialect; +import org.jooq.impl.DSL; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StatePersistenceTest extends BaseDatabaseConfigPersistenceTest { + + private StatePersistence statePersistence; + private UUID connectionId; + + @Test + public void testReadingNonExistingState() throws IOException { + Assertions.assertTrue(statePersistence.getCurrentState(UUID.randomUUID()).isEmpty()); + } + + @Test + public void testLegacyReadWrite() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.LEGACY) + .withLegacyState(Jsons.deserialize("{\"woot\": \"legacy states is passthrough\"}")); + + // Initial write/read loop, making sure we read what we wrote + statePersistence.updateOrCreateState(connectionId, state0); + final Optional state1 = statePersistence.getCurrentState(connectionId); + + Assertions.assertTrue(state1.isPresent()); + Assertions.assertEquals(StateType.LEGACY, state1.get().getStateType()); + Assertions.assertEquals(state0.getLegacyState(), state1.get().getLegacyState()); + + // Updating a state + final JsonNode newStateJson = Jsons.deserialize("{\"woot\": \"new state\"}"); + final StateWrapper state2 = clone(state1.get()).withLegacyState(newStateJson); + statePersistence.updateOrCreateState(connectionId, state2); + final Optional state3 = statePersistence.getCurrentState(connectionId); + + Assertions.assertTrue(state3.isPresent()); + Assertions.assertEquals(StateType.LEGACY, state3.get().getStateType()); + Assertions.assertEquals(newStateJson, state3.get().getLegacyState()); + + // Deleting a state + final StateWrapper state4 = clone(state3.get()).withLegacyState(null); + statePersistence.updateOrCreateState(connectionId, state4); + Assertions.assertTrue(statePersistence.getCurrentState(connectionId).isEmpty()); + } + + @Test + public void testLegacyMigrationToGlobal() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.LEGACY) + .withLegacyState(Jsons.deserialize("{\"woot\": \"legacy states is passthrough\"}")); + + statePersistence.updateOrCreateState(connectionId, state0); + + final StateWrapper newGlobalState = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"woot\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n2")) + .withStreamState(Jsons.deserialize("\"state1\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(Jsons.deserialize("\"state2\"")))))); + statePersistence.updateOrCreateState(connectionId, newGlobalState); + final StateWrapper storedGlobalState = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals(newGlobalState, storedGlobalState); + } + + @Test + public void testLegacyMigrationToStream() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.LEGACY) + .withLegacyState(Jsons.deserialize("{\"woot\": \"legacy states is passthrough\"}")); + + statePersistence.updateOrCreateState(connectionId, state0); + + final StateWrapper newStreamState = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"state s1.n1\""))), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(Jsons.deserialize("\"state s2\""))))); + statePersistence.updateOrCreateState(connectionId, newStreamState); + final StateWrapper storedStreamState = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals(newStreamState, storedStreamState); + } + + @Test + public void testGlobalReadWrite() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n2")) + .withStreamState(Jsons.deserialize("\"state1\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(Jsons.deserialize("\"state2\"")))))); + + // Initial write/read loop, making sure we read what we wrote + statePersistence.updateOrCreateState(connectionId, state0); + final Optional state1 = statePersistence.getCurrentState(connectionId); + Assertions.assertTrue(state1.isPresent()); + assertEquals(state0, state1.get()); + + // Updating a state + final StateWrapper state2 = clone(state1.get()); + state2.getGlobal() + .getGlobal().withSharedState(Jsons.deserialize("\"updated shared state\"")) + .getStreamStates().get(1).withStreamState(Jsons.deserialize("\"updated state2\"")); + statePersistence.updateOrCreateState(connectionId, state2); + final Optional state3 = statePersistence.getCurrentState(connectionId); + + Assertions.assertTrue(state3.isPresent()); + assertEquals(state2, state3.get()); + + // Updating a state with name and namespace + final StateWrapper state4 = clone(state1.get()); + state4.getGlobal().getGlobal() + .getStreamStates().get(0).withStreamState(Jsons.deserialize("\"updated state1\"")); + statePersistence.updateOrCreateState(connectionId, state4); + final Optional state5 = statePersistence.getCurrentState(connectionId); + + Assertions.assertTrue(state5.isPresent()); + assertEquals(state4, state5.get()); + } + + @Test + public void testGlobalPartialReset() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n2")) + .withStreamState(Jsons.deserialize("\"state1\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(Jsons.deserialize("\"state2\"")))))); + + // Set the initial state + statePersistence.updateOrCreateState(connectionId, state0); + + // incomplete reset does not remove the state + final StateWrapper incompletePartialReset = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(Jsons.deserialize("\"state2\"")))))); + statePersistence.updateOrCreateState(connectionId, incompletePartialReset); + final StateWrapper incompletePartialResetResult = statePersistence.getCurrentState(connectionId).orElseThrow(); + Assertions.assertEquals(state0, incompletePartialResetResult); + + // The good partial reset + final StateWrapper partialReset = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n2")) + .withStreamState(Jsons.deserialize("\"state1\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(null))))); + statePersistence.updateOrCreateState(connectionId, partialReset); + final StateWrapper partialResetResult = statePersistence.getCurrentState(connectionId).orElseThrow(); + + Assertions.assertEquals(partialReset.getGlobal().getGlobal().getSharedState(), + partialResetResult.getGlobal().getGlobal().getSharedState()); + // {"name": "s1"} should have been removed from the stream states + Assertions.assertEquals(1, partialResetResult.getGlobal().getGlobal().getStreamStates().size()); + Assertions.assertEquals(partialReset.getGlobal().getGlobal().getStreamStates().get(0), + partialResetResult.getGlobal().getGlobal().getStreamStates().get(0)); + } + + @Test + public void testGlobalFullReset() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n2")) + .withStreamState(Jsons.deserialize("\"state1\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(Jsons.deserialize("\"state2\"")))))); + + final StateWrapper fullReset = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(null) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n2")) + .withStreamState(null), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1")) + .withStreamState(null)))));; + + statePersistence.updateOrCreateState(connectionId, state0); + statePersistence.updateOrCreateState(connectionId, fullReset); + final Optional fullResetResult = statePersistence.getCurrentState(connectionId); + Assertions.assertTrue(fullResetResult.isEmpty()); + } + + @Test + public void testGlobalStateAllowsEmptyNameAndNamespace() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("")) + .withStreamState(Jsons.deserialize("\"empty name state\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("").withNamespace("")) + .withStreamState(Jsons.deserialize("\"empty name and namespace state\"")))))); + + statePersistence.updateOrCreateState(connectionId, state0); + final StateWrapper state1 = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals(state0, state1); + } + + @Test + public void testStreamReadWrite() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"state s1.n1\""))), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(Jsons.deserialize("\"state s2\""))))); + + // Initial write/read loop, making sure we read what we wrote + statePersistence.updateOrCreateState(connectionId, state0); + final StateWrapper state1 = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals(state0, state1); + + // Updating a state + final StateWrapper state2 = clone(state1); + state2.getStateMessages().get(1).getStream().withStreamState(Jsons.deserialize("\"updated state s2\"")); + statePersistence.updateOrCreateState(connectionId, state2); + final StateWrapper state3 = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals(state2, state3); + + // Updating a state with name and namespace + final StateWrapper state4 = clone(state1); + state4.getStateMessages().get(0).getStream().withStreamState(Jsons.deserialize("\"updated state s1\"")); + statePersistence.updateOrCreateState(connectionId, state4); + final StateWrapper state5 = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals(state4, state5); + } + + @Test + public void testStreamPartialUpdates() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"state s1.n1\""))), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(Jsons.deserialize("\"state s2\""))))); + + statePersistence.updateOrCreateState(connectionId, state0); + + // Partial update + final StateWrapper partialUpdate = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Collections.singletonList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"updated\""))))); + statePersistence.updateOrCreateState(connectionId, partialUpdate); + final StateWrapper partialUpdateResult = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals( + new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"updated\""))), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(Jsons.deserialize("\"state s2\""))))), + partialUpdateResult); + + // Partial Reset + final StateWrapper partialReset = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Collections.singletonList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(null)))); + statePersistence.updateOrCreateState(connectionId, partialReset); + final StateWrapper partialResetResult = statePersistence.getCurrentState(connectionId).orElseThrow(); + assertEquals( + new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"updated\""))))), + partialResetResult); + } + + @Test + public void testStreamFullReset() throws IOException { + final StateWrapper state0 = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"state s1.n1\""))), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(Jsons.deserialize("\"state s2\""))))); + + statePersistence.updateOrCreateState(connectionId, state0); + + // Partial update + final StateWrapper fullReset = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(null)), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(null)))); + statePersistence.updateOrCreateState(connectionId, fullReset); + final Optional fullResetResult = statePersistence.getCurrentState(connectionId); + Assertions.assertTrue(fullResetResult.isEmpty()); + } + + @Test + public void testInconsistentTypeUpdates() throws IOException { + final StateWrapper streamState = new StateWrapper() + .withStateType(StateType.STREAM) + .withStateMessages(Arrays.asList( + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s1").withNamespace("n1")) + .withStreamState(Jsons.deserialize("\"state s1.n1\""))), + new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream(new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("s2")) + .withStreamState(Jsons.deserialize("\"state s2\""))))); + statePersistence.updateOrCreateState(connectionId, streamState); + + Assertions.assertThrows(IllegalStateException.class, () -> { + final StateWrapper globalState = new StateWrapper() + .withStateType(StateType.GLOBAL) + .withGlobal(new AirbyteStateMessage() + .withType(AirbyteStateType.GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withSharedState(Jsons.deserialize("\"my global state\"")) + .withStreamStates(Arrays.asList( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("")) + .withStreamState(Jsons.deserialize("\"empty name state\"")), + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor().withName("").withNamespace("")) + .withStreamState(Jsons.deserialize("\"empty name and namespace state\"")))))); + statePersistence.updateOrCreateState(connectionId, globalState); + }); + + // We should be guarded against those cases let's make sure we don't make things worse if we're in + // an inconsistent state + dslContext.insertInto(DSL.table("state")) + .columns(DSL.field("id"), DSL.field("connection_id"), DSL.field("type"), DSL.field("state")) + .values(UUID.randomUUID(), connectionId, io.airbyte.db.instance.configs.jooq.generated.enums.StateType.GLOBAL, JSONB.valueOf("{}")) + .execute(); + Assertions.assertThrows(IllegalStateException.class, () -> statePersistence.updateOrCreateState(connectionId, streamState)); + Assertions.assertThrows(IllegalStateException.class, () -> statePersistence.getCurrentState(connectionId)); + } + + @Test + public void testEnumsConversion() { + // Making sure StateType we write to the DB and the StateType from the protocols are aligned. + // Otherwise, we'll have to dig through runtime errors. + Assertions.assertTrue(Enums.isCompatible( + io.airbyte.db.instance.configs.jooq.generated.enums.StateType.class, + io.airbyte.config.StateType.class)); + } + + @BeforeEach + public void beforeEach() throws DatabaseInitializationException, IOException, JsonValidationException { + dataSource = DatabaseConnectionHelper.createDataSource(container); + dslContext = DSLContextFactory.create(dataSource, SQLDialect.POSTGRES); + flyway = FlywayFactory.create(dataSource, DatabaseConfigPersistenceLoadDataTest.class.getName(), + ConfigsDatabaseMigrator.DB_IDENTIFIER, ConfigsDatabaseMigrator.MIGRATION_FILE_LOCATION); + database = new ConfigsDatabaseTestProvider(dslContext, flyway).create(true); + setupTestData(); + + statePersistence = new StatePersistence(database); + } + + @AfterEach + public void afterEach() { + // Making sure we reset between tests + dslContext.dropSchemaIfExists("public").cascade().execute(); + dslContext.createSchema("public").execute(); + dslContext.setSchema("public").execute(); + } + + private void setupTestData() throws JsonValidationException, IOException { + ConfigRepository configRepository = new ConfigRepository( + new DatabaseConfigPersistence(database, mock(JsonSecretsProcessor.class)), + database); + + final StandardWorkspace workspace = MockData.standardWorkspaces().get(0); + final StandardSourceDefinition sourceDefinition = MockData.publicSourceDefinition(); + final SourceConnection sourceConnection = MockData.sourceConnections().get(0); + final StandardDestinationDefinition destinationDefinition = MockData.publicDestinationDefinition(); + final DestinationConnection destinationConnection = MockData.destinationConnections().get(0); + final StandardSync sync = MockData.standardSyncs().get(0); + + configRepository.writeStandardWorkspace(workspace); + configRepository.writeStandardSourceDefinition(sourceDefinition); + configRepository.writeSourceConnectionNoSecrets(sourceConnection); + configRepository.writeStandardDestinationDefinition(destinationDefinition); + configRepository.writeDestinationConnectionNoSecrets(destinationConnection); + configRepository.writeStandardSyncOperation(MockData.standardSyncOperations().get(0)); + configRepository.writeStandardSyncOperation(MockData.standardSyncOperations().get(1)); + configRepository.writeStandardSync(sync); + + connectionId = sync.getConnectionId(); + } + + private StateWrapper clone(final StateWrapper state) { + return switch (state.getStateType()) { + case LEGACY -> new StateWrapper() + .withLegacyState(Jsons.deserialize(Jsons.serialize(state.getLegacyState()))) + .withStateType(state.getStateType()); + case STREAM -> new StateWrapper() + .withStateMessages( + state.getStateMessages().stream().map(msg -> Jsons.deserialize(Jsons.serialize(msg), AirbyteStateMessage.class)).toList()) + .withStateType(state.getStateType()); + case GLOBAL -> new StateWrapper() + .withGlobal(Jsons.deserialize(Jsons.serialize(state.getGlobal()), AirbyteStateMessage.class)) + .withStateType(state.getStateType()); + }; + } + + private void assertEquals(StateWrapper lhs, StateWrapper rhs) { + Assertions.assertEquals(Jsons.serialize(lhs), Jsons.serialize(rhs)); + } + +}