diff --git a/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/DbInitHelper.kt b/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/DbInitHelper.kt index 231c985f2..a2d2fb5bd 100644 --- a/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/DbInitHelper.kt +++ b/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/DbInitHelper.kt @@ -36,14 +36,15 @@ import fr.acinq.phoenix.db.sqldelight.* import fr.acinq.phoenix.managers.PaymentMetadataQueue import fr.acinq.phoenix.utils.extensions.toByteArray -fun createSqliteChannelsDb(driver: SqlDriver): SqliteChannelsDb { +fun createSqliteChannelsDb(driver: SqlDriver, loggerFactory: LoggerFactory): SqliteChannelsDb { return SqliteChannelsDb( driver = driver, database = ChannelsDatabase( driver = driver, htlc_infosAdapter = Htlc_infos.Adapter(ByteVector32Adapter, ByteVector32Adapter), - local_channelsAdapter = Local_channels.Adapter(ByteVector32Adapter, PersistedChannelStateAdapter) - ) + local_channelsAdapter = Local_channels.Adapter(ByteVector32Adapter), + ), + loggerFactory = loggerFactory ) } diff --git a/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/SqliteChannelsDb.kt b/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/SqliteChannelsDb.kt index 1b6430e62..567530568 100644 --- a/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/SqliteChannelsDb.kt +++ b/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/db/SqliteChannelsDb.kt @@ -21,22 +21,24 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.lightning.CltvExpiry import fr.acinq.lightning.channel.states.PersistedChannelState import fr.acinq.lightning.db.ChannelsDb +import fr.acinq.lightning.logging.LoggerFactory import fr.acinq.phoenix.db.sqldelight.ChannelsDatabase import kotlin.collections.List import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : ChannelsDb { +class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase, loggerFactory: LoggerFactory) : ChannelsDb { + val log = loggerFactory.newLogger(this::class) private val queries = database.channelsDatabaseQueries override suspend fun addOrUpdateChannel(state: PersistedChannelState) { withContext(Dispatchers.Default) { queries.transaction { queries.getChannel(state.channelId).executeAsOneOrNull()?.run { - queries.updateChannel(channel_id = state.channelId, data_ = state) + queries.updateChannel(channel_id = state.channelId, data_ = PersistedChannelStateAdapter.encode(state)) } ?: run { - queries.insertChannel(channel_id = state.channelId, data_ = state) + queries.insertChannel(channel_id = state.channelId, data_ = PersistedChannelStateAdapter.encode(state)) } } } @@ -44,9 +46,11 @@ class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : Chan suspend fun getChannel(channelId: ByteVector32): Triple? { return withContext(Dispatchers.Default) { - queries.getChannel(channelId, mapper = { channelId, data, isClosed -> - Triple(channelId, data, isClosed) - }).executeAsOneOrNull() + queries.getChannel(channelId).executeAsOneOrNull()?.let { (channelId, data, isClosed) -> + mapChannelData(channelId, data)?.let { + Triple(channelId, it, isClosed) + } + } } } @@ -58,7 +62,7 @@ class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : Chan } override suspend fun listLocalChannels(): List = withContext(Dispatchers.Default) { - queries.listLocalChannels().executeAsList() + queries.listLocalChannels().executeAsList().mapNotNull { (channelId, data) -> mapChannelData(channelId, data) } } override suspend fun addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry) { @@ -80,6 +84,15 @@ class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : Chan } } + private fun mapChannelData(channelId: ByteVector32, data: ByteArray): PersistedChannelState? { + return try { + PersistedChannelStateAdapter.decode(data) + } catch (e: Exception) { + log.e(e) { "failed to read channel data for channel=$channelId :" } + null + } + } + override fun close() { driver.close() } diff --git a/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/managers/DatabaseManager.kt b/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/managers/DatabaseManager.kt index 1a135ff8c..7c0ee55e9 100644 --- a/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/managers/DatabaseManager.kt +++ b/phoenix-shared/src/commonMain/kotlin/fr.acinq.phoenix/managers/DatabaseManager.kt @@ -72,7 +72,7 @@ class DatabaseManager( log.debug { "nodeParams available: building databases..." } val channelsDbDriver = createChannelsDbDriver(ctx, channelsDbName(chain, nodeParams.nodeId)) - val channelsDb = createSqliteChannelsDb(channelsDbDriver) + val channelsDb = createSqliteChannelsDb(channelsDbDriver, loggerFactory) val paymentsDbDriver = createPaymentsDbDriver(ctx, paymentsDbName(chain, nodeParams.nodeId)) { log.e { "payments-db migration error: $it" } } val paymentsDb = createSqlitePaymentsDb(paymentsDbDriver, paymentMetadataQueue, loggerFactory) val cloudKitDb = makeCloudKitDb(appDb, paymentsDb) diff --git a/phoenix-shared/src/commonMain/sqldelight/channelsdb/fr/acinq/phoenix/db/sqldelight/ChannelsDatabase.sq b/phoenix-shared/src/commonMain/sqldelight/channelsdb/fr/acinq/phoenix/db/sqldelight/ChannelsDatabase.sq index a90637b22..5884fbb42 100644 --- a/phoenix-shared/src/commonMain/sqldelight/channelsdb/fr/acinq/phoenix/db/sqldelight/ChannelsDatabase.sq +++ b/phoenix-shared/src/commonMain/sqldelight/channelsdb/fr/acinq/phoenix/db/sqldelight/ChannelsDatabase.sq @@ -1,12 +1,11 @@ import fr.acinq.bitcoin.ByteVector32; -import fr.acinq.lightning.channel.states.PersistedChannelState; import kotlin.Boolean; -- channels table -- note: boolean are stored as INTEGER, with 0=false CREATE TABLE local_channels ( channel_id BLOB AS ByteVector32 NOT NULL PRIMARY KEY, - data BLOB AS PersistedChannelState NOT NULL, + data BLOB NOT NULL, is_closed INTEGER AS Boolean DEFAULT 0 NOT NULL ); @@ -35,7 +34,7 @@ closeLocalChannel: UPDATE local_channels SET is_closed=1 WHERE channel_id=?; listLocalChannels: -SELECT data FROM local_channels WHERE is_closed=0; +SELECT channel_id, data FROM local_channels WHERE is_closed=0; -- htlcs info queries insertHtlcInfo: diff --git a/phoenix-shared/src/commonTest/kotlin/fr/acinq/phoenix/db/SqliteChannelsDbTest.kt b/phoenix-shared/src/commonTest/kotlin/fr/acinq/phoenix/db/SqliteChannelsDbTest.kt index acbc9e8a1..cfca31399 100644 --- a/phoenix-shared/src/commonTest/kotlin/fr/acinq/phoenix/db/SqliteChannelsDbTest.kt +++ b/phoenix-shared/src/commonTest/kotlin/fr/acinq/phoenix/db/SqliteChannelsDbTest.kt @@ -19,6 +19,7 @@ package fr.acinq.phoenix.db import app.cash.sqldelight.db.SqlDriver import fr.acinq.bitcoin.ByteVector32 import fr.acinq.phoenix.utils.runTest +import fr.acinq.phoenix.utils.testLoggerFactory import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -46,7 +47,7 @@ class SqliteChannelsDbTest : UsingContextTest() { @Test fun `read v1 db`() = runTest { val driver = createChannelsDbDriver(getPlatformContext(), fileName = "channels-testnet-fe646b99.sqlite") - val channelsDb = createSqliteChannelsDb(driver) + val channelsDb = createSqliteChannelsDb(driver, testLoggerFactory) val channels = channelsDb.listLocalChannels()