Skip to content

Commit 6518bb4

Browse files
authored
Postgres: fix concurrency in channels db (#1762)
* preserve pg lock exception cause
* specialize connections by backend type
* added concurrency test on channels table

  This test unveils a concurrency issue in the upsert logic of the local channels db, with the following error being thrown when we update many channels concurrently:

  ```
  Canceled on identification as a pivot, during conflict out checking
  ```

* use pg upsert construct

  This is the recommended pattern according to the Postgres documentation (https://www.postgresql.org/docs/current/plpgsql-control-structures.html#PLPGSQL-UPSERT-EXAMPLE):

  > It is recommended that applications use INSERT with ON CONFLICT DO UPDATE rather than actually using this pattern.

* reproduce and fix same issue in peers db
1 parent 1e2abae commit 6518bb4

File tree

6 files changed

+89
-40
lines changed

6 files changed

+89
-40
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala

+10-10
Original file line number | Diff line number | Diff line change
@@ -67,16 +67,16 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
6767
override def addOrUpdateChannel(state: HasCommitments): Unit = withMetrics("channels/add-or-update-channel", DbBackends.Postgres) {
6868
withLock { pg =>
6969
val data = stateDataCodec.encode(state).require.toByteArray
70-
using(pg.prepareStatement("UPDATE local_channels SET data=? WHERE channel_id=?")) { update =>
71-
update.setBytes(1, data)
72-
update.setString(2, state.channelId.toHex)
73-
if (update.executeUpdate() == 0) {
74-
using(pg.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, FALSE)")) { statement =>
75-
statement.setString(1, state.channelId.toHex)
76-
statement.setBytes(2, data)
77-
statement.executeUpdate()
78-
}
79-
}
70+
using(pg.prepareStatement(
71+
"""
72+
| INSERT INTO local_channels (channel_id, data, is_closed)
73+
| VALUES (?, ?, FALSE)
74+
| ON CONFLICT (channel_id)
75+
| DO UPDATE SET data = EXCLUDED.data ;
76+
| """.stripMargin)) { statement =>
77+
statement.setString(1, state.channelId.toHex)
78+
statement.setBytes(2, data)
79+
statement.executeUpdate()
8080
}
8181
}
8282
}

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala

+10-10
Original file line number | Diff line number | Diff line change
@@ -46,16 +46,16 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb {
4646
override def addOrUpdatePeer(nodeId: Crypto.PublicKey, nodeaddress: NodeAddress): Unit = withMetrics("peers/add-or-update", DbBackends.Postgres) {
4747
withLock { pg =>
4848
val data = CommonCodecs.nodeaddress.encode(nodeaddress).require.toByteArray
49-
using(pg.prepareStatement("UPDATE peers SET data=? WHERE node_id=?")) { update =>
50-
update.setBytes(1, data)
51-
update.setString(2, nodeId.value.toHex)
52-
if (update.executeUpdate() == 0) {
53-
using(pg.prepareStatement("INSERT INTO peers VALUES (?, ?)")) { statement =>
54-
statement.setString(1, nodeId.value.toHex)
55-
statement.setBytes(2, data)
56-
statement.executeUpdate()
57-
}
58-
}
49+
using(pg.prepareStatement(
50+
"""
51+
| INSERT INTO peers (node_id, data)
52+
| VALUES (?, ?)
53+
| ON CONFLICT (node_id)
54+
| DO UPDATE SET data = EXCLUDED.data ;
55+
| """.stripMargin)) { statement =>
56+
statement.setString(1, nodeId.value.toHex)
57+
statement.setBytes(2, data)
58+
statement.executeUpdate()
5959
}
6060
}
6161
}

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala

+4-1
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,10 @@ object PgUtils extends JdbcUtils {
6060
logger.error(s"cannot obtain lock on the database ($other).")
6161
}
6262

63-
case class LockException(lockFailure: LockFailure) extends RuntimeException("a lock exception occurred")
63+
case class LockException(lockFailure: LockFailure) extends RuntimeException("a lock exception occurred", lockFailure match {
64+
case LockFailure.GeneralLockException(cause) => cause // when the origin is an exception, we provide it to have a nice stack trace
65+
case _ => null
66+
})
6467

6568
/**
6669
* This handler is useful in tests

eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala

+9-4
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock
88
import fr.acinq.eclair.db.sqlite.SqliteUtils
99
import fr.acinq.eclair.db._
1010
import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler
11+
import org.postgresql.jdbc.PgConnection
12+
import org.sqlite.SQLiteConnection
1113

1214
import java.io.File
1315
import java.sql.{Connection, DriverManager, Statement}
@@ -36,13 +38,16 @@ sealed trait TestDatabases extends Databases {
3638

3739
object TestDatabases {
3840

39-
def sqliteInMemory(): Connection = DriverManager.getConnection("jdbc:sqlite::memory:")
41+
def sqliteInMemory(): SQLiteConnection = DriverManager.getConnection("jdbc:sqlite::memory:").asInstanceOf[SQLiteConnection]
4042

41-
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.SqliteDatabases(connection, connection, connection)
43+
def inMemoryDb(): Databases = {
44+
val connection = sqliteInMemory()
45+
Databases.SqliteDatabases(connection, connection, connection)
46+
}
4247

4348
case class TestSqliteDatabases() extends TestDatabases {
4449
// @formatter:off
45-
override val connection: Connection = sqliteInMemory()
50+
override val connection: SQLiteConnection = sqliteInMemory()
4651
override lazy val db: Databases = Databases.SqliteDatabases(connection, connection, connection)
4752
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = SqliteUtils.getVersion(statement, db_name, currentVersion)
4853
override def close(): Unit = ()
@@ -62,7 +67,7 @@ object TestDatabases {
6267
implicit val system: ActorSystem = ActorSystem()
6368

6469
// @formatter:off
65-
override val connection: Connection = pg.getPostgresDatabase.getConnection
70+
override val connection: PgConnection = pg.getPostgresDatabase.getConnection.asInstanceOf[PgConnection]
6671
override lazy val db: Databases = Databases.PostgresDatabases(hikariConfig, UUID.randomUUID(), lock, jdbcUrlFile_opt = Some(jdbcUrlFile))
6772
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = PgUtils.getVersion(statement, db_name, currentVersion)
6873
override def close(): Unit = pg.close()

eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala

+39-15
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,14 @@ import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
2626
import fr.acinq.eclair.db.sqlite.SqliteUtils.ExtendedResultSet._
2727
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
2828
import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec
29-
import fr.acinq.eclair.{CltvExpiry, randomBytes32}
29+
import fr.acinq.eclair.{CltvExpiry, ShortChannelId, randomBytes32}
3030
import org.scalatest.funsuite.AnyFunSuite
3131
import scodec.bits.ByteVector
3232

3333
import java.sql.SQLException
34+
import java.util.concurrent.Executors
35+
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
36+
import scala.concurrent.duration._
3437

3538
class ChannelsDbSpec extends AnyFunSuite {
3639

@@ -52,30 +55,51 @@ class ChannelsDbSpec extends AnyFunSuite {
5255
val db = dbs.channels
5356
dbs.pendingRelay // needed by db.removeChannel
5457

55-
val channel = ChannelCodecsSpec.normal
58+
val channel1 = ChannelCodecsSpec.normal
59+
val channel2a = ChannelCodecsSpec.normal.modify(_.commitments.channelId).setTo(randomBytes32)
60+
val channel2b = channel2a.modify(_.shortChannelId).setTo(ShortChannelId(189371))
5661

5762
val commitNumber = 42
5863
val paymentHash1 = ByteVector32.Zeroes
5964
val cltvExpiry1 = CltvExpiry(123)
6065
val paymentHash2 = ByteVector32(ByteVector.fill(32)(1))
6166
val cltvExpiry2 = CltvExpiry(656)
6267

63-
intercept[SQLException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
68+
intercept[SQLException](db.addHtlcInfo(channel1.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
6469

6570
assert(db.listLocalChannels().toSet === Set.empty)
66-
db.addOrUpdateChannel(channel)
67-
db.addOrUpdateChannel(channel)
68-
assert(db.listLocalChannels() === List(channel))
69-
70-
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
71-
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)
72-
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2)
73-
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList.toSet == Set((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
74-
assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil)
75-
76-
db.removeChannel(channel.channelId)
71+
db.addOrUpdateChannel(channel1)
72+
db.addOrUpdateChannel(channel1)
73+
assert(db.listLocalChannels() === List(channel1))
74+
db.addOrUpdateChannel(channel2a)
75+
assert(db.listLocalChannels() === List(channel1, channel2a))
76+
db.addOrUpdateChannel(channel2b)
77+
assert(db.listLocalChannels() === List(channel1, channel2b))
78+
79+
assert(db.listHtlcInfos(channel1.channelId, commitNumber).toList == Nil)
80+
db.addHtlcInfo(channel1.channelId, commitNumber, paymentHash1, cltvExpiry1)
81+
db.addHtlcInfo(channel1.channelId, commitNumber, paymentHash2, cltvExpiry2)
82+
assert(db.listHtlcInfos(channel1.channelId, commitNumber).toList.toSet == Set((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
83+
assert(db.listHtlcInfos(channel1.channelId, 43).toList == Nil)
84+
85+
db.removeChannel(channel1.channelId)
86+
assert(db.listLocalChannels() === List(channel2b))
87+
assert(db.listHtlcInfos(channel1.channelId, commitNumber).toList == Nil)
88+
db.removeChannel(channel2b.channelId)
7789
assert(db.listLocalChannels() === Nil)
78-
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
90+
}
91+
}
92+
93+
test("concurrent channel updates") {
94+
forAllDbs { dbs =>
95+
val db = dbs.channels
96+
implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(8))
97+
val channel = ChannelCodecsSpec.normal
98+
val futures = for (_ <- 0 until 10000) yield {
99+
Future(db.addOrUpdateChannel(channel.modify(_.commitments.channelId).setTo(randomBytes32)))
100+
}
101+
val res = Future.sequence(futures)
102+
Await.result(res, 60 seconds)
79103
}
80104
}
81105

eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala

+17
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,10 @@ import fr.acinq.eclair.randomKey
2424
import fr.acinq.eclair.wire.protocol.{NodeAddress, Tor2, Tor3}
2525
import org.scalatest.funsuite.AnyFunSuite
2626

27+
import java.util.concurrent.Executors
28+
import scala.concurrent.duration._
29+
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
30+
import scala.util.Success
2731

2832
class PeersDbSpec extends AnyFunSuite {
2933

@@ -68,4 +72,17 @@ class PeersDbSpec extends AnyFunSuite {
6872
}
6973
}
7074

75+
test("concurrent peer updates") {
76+
forAllDbs { dbs =>
77+
val db = dbs.peers
78+
implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(8))
79+
val Success(peerAddress) = NodeAddress.fromParts("127.0.0.1", 42000)
80+
val futures = for (_ <- 0 until 10000) yield {
81+
Future(db.addOrUpdatePeer(randomKey.publicKey, peerAddress))
82+
}
83+
val res = Future.sequence(futures)
84+
Await.result(res, 60 seconds)
85+
}
86+
}
87+
7188
}

0 commit comments

Comments
 (0)