Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ import fr.acinq.phoenix.db.sqldelight.*
import fr.acinq.phoenix.managers.PaymentMetadataQueue
import fr.acinq.phoenix.utils.extensions.toByteArray

fun createSqliteChannelsDb(driver: SqlDriver): SqliteChannelsDb {
fun createSqliteChannelsDb(driver: SqlDriver, loggerFactory: LoggerFactory): SqliteChannelsDb {
return SqliteChannelsDb(
driver = driver,
database = ChannelsDatabase(
driver = driver,
htlc_infosAdapter = Htlc_infos.Adapter(ByteVector32Adapter, ByteVector32Adapter),
local_channelsAdapter = Local_channels.Adapter(ByteVector32Adapter, PersistedChannelStateAdapter)
)
local_channelsAdapter = Local_channels.Adapter(ByteVector32Adapter),
),
loggerFactory = loggerFactory
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,36 @@ import fr.acinq.bitcoin.ByteVector32
import fr.acinq.lightning.CltvExpiry
import fr.acinq.lightning.channel.states.PersistedChannelState
import fr.acinq.lightning.db.ChannelsDb
import fr.acinq.lightning.logging.LoggerFactory
import fr.acinq.phoenix.db.sqldelight.ChannelsDatabase
import kotlin.collections.List
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : ChannelsDb {
class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase, loggerFactory: LoggerFactory) : ChannelsDb {

val log = loggerFactory.newLogger(this::class)
private val queries = database.channelsDatabaseQueries

override suspend fun addOrUpdateChannel(state: PersistedChannelState) {
withContext(Dispatchers.Default) {
queries.transaction {
queries.getChannel(state.channelId).executeAsOneOrNull()?.run {
queries.updateChannel(channel_id = state.channelId, data_ = state)
queries.updateChannel(channel_id = state.channelId, data_ = PersistedChannelStateAdapter.encode(state))
} ?: run {
queries.insertChannel(channel_id = state.channelId, data_ = state)
queries.insertChannel(channel_id = state.channelId, data_ = PersistedChannelStateAdapter.encode(state))
}
}
}
}

suspend fun getChannel(channelId: ByteVector32): Triple<ByteVector32, PersistedChannelState, Boolean>? {
return withContext(Dispatchers.Default) {
queries.getChannel(channelId, mapper = { channelId, data, isClosed ->
Triple(channelId, data, isClosed)
}).executeAsOneOrNull()
queries.getChannel(channelId).executeAsOneOrNull()?.let { (channelId, data, isClosed) ->
mapChannelData(channelId, data)?.let {
Triple(channelId, it, isClosed)
}
}
}
}

Expand All @@ -58,7 +62,7 @@ class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : Chan
}

override suspend fun listLocalChannels(): List<PersistedChannelState> = withContext(Dispatchers.Default) {
queries.listLocalChannels().executeAsList()
queries.listLocalChannels().executeAsList().mapNotNull { (channelId, data) -> mapChannelData(channelId, data) }
}

override suspend fun addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry) {
Expand All @@ -80,6 +84,15 @@ class SqliteChannelsDb(val driver: SqlDriver, database: ChannelsDatabase) : Chan
}
}

private fun mapChannelData(channelId: ByteVector32, data: ByteArray): PersistedChannelState? {
return try {
PersistedChannelStateAdapter.decode(data)
} catch (e: Exception) {
log.e(e) { "failed to read channel data for channel=$channelId :" }
null
}
}

override fun close() {
driver.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class DatabaseManager(
log.debug { "nodeParams available: building databases..." }

val channelsDbDriver = createChannelsDbDriver(ctx, channelsDbName(chain, nodeParams.nodeId))
val channelsDb = createSqliteChannelsDb(channelsDbDriver)
val channelsDb = createSqliteChannelsDb(channelsDbDriver, loggerFactory)
val paymentsDbDriver = createPaymentsDbDriver(ctx, paymentsDbName(chain, nodeParams.nodeId)) { log.e { "payments-db migration error: $it" } }
val paymentsDb = createSqlitePaymentsDb(paymentsDbDriver, paymentMetadataQueue, loggerFactory)
val cloudKitDb = makeCloudKitDb(appDb, paymentsDb)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import fr.acinq.bitcoin.ByteVector32;
import fr.acinq.lightning.channel.states.PersistedChannelState;
import kotlin.Boolean;

-- channels table
-- note: boolean are stored as INTEGER, with 0=false
CREATE TABLE local_channels (
channel_id BLOB AS ByteVector32 NOT NULL PRIMARY KEY,
data BLOB AS PersistedChannelState NOT NULL,
data BLOB NOT NULL,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't immediately assume the data can be decoded anymore, so that we can catch and log decoding errors.

is_closed INTEGER AS Boolean DEFAULT 0 NOT NULL
);

Expand Down Expand Up @@ -35,7 +34,7 @@ closeLocalChannel:
UPDATE local_channels SET is_closed=1 WHERE channel_id=?;

listLocalChannels:
SELECT data FROM local_channels WHERE is_closed=0;
SELECT channel_id, data FROM local_channels WHERE is_closed=0;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the channel_id here allows us to reuse in listLocalChannels the mapChannelData already used by getChannel. Also, lets us log the id of the faulty channel.


-- htlcs info queries
insertHtlcInfo:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package fr.acinq.phoenix.db
import app.cash.sqldelight.db.SqlDriver
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.phoenix.utils.runTest
import fr.acinq.phoenix.utils.testLoggerFactory
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
Expand Down Expand Up @@ -46,7 +47,7 @@ class SqliteChannelsDbTest : UsingContextTest() {
@Test
fun `read v1 db`() = runTest {
val driver = createChannelsDbDriver(getPlatformContext(), fileName = "channels-testnet-fe646b99.sqlite")
val channelsDb = createSqliteChannelsDb(driver)
val channelsDb = createSqliteChannelsDb(driver, testLoggerFactory)

val channels = channelsDb.listLocalChannels()

Expand Down