Skip to content

Commit 4dc2910

Browse files
authored
Make result set an iterable (#1823)
This allows us to use the full power of scala collections, to iterate over results, convert to options, etc. while staying purely functional and immutable. There is a catch though: the iterator is lazy, it must be materialized before the result set is closed, by converting the end result in a collection or an option. In other words, database methods must never return an `Iterable` or `Iterator`.
1 parent f829a2e commit 4dc2910

23 files changed

+363
-490
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import akka.actor.{Actor, ActorLogging, Props}
2020
import fr.acinq.bitcoin.Crypto.PublicKey
2121
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
2222
import fr.acinq.eclair.NodeParams
23-
import fr.acinq.eclair.channel.Helpers.Closing.{ClosingType, CurrentRemoteClose, LocalClose, MutualClose, NextRemoteClose, RecoveryClose, RevokedClose}
23+
import fr.acinq.eclair.channel.Helpers.Closing._
2424
import fr.acinq.eclair.channel.Monitoring.{Metrics => ChannelMetrics, Tags => ChannelTags}
2525
import fr.acinq.eclair.channel._
2626
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent

eclair-core/src/main/scala/fr/acinq/eclair/db/FeeratesDb.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
package fr.acinq.eclair.db
1818

19-
import java.io.Closeable
20-
2119
import fr.acinq.eclair.blockchain.fee.FeeratesPerKB
2220

21+
import java.io.Closeable
22+
2323
/**
2424
* This database stores the fee rates retrieved by a [[fr.acinq.eclair.blockchain.fee.FeeProvider]].
2525
*/

eclair-core/src/main/scala/fr/acinq/eclair/db/FileBackupHandler.scala

+2-3
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,15 @@
1616

1717
package fr.acinq.eclair.db
1818

19-
import java.io.File
20-
import java.nio.file.{Files, StandardCopyOption}
21-
2219
import akka.actor.{Actor, ActorLogging, Props}
2320
import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue}
2421
import fr.acinq.eclair.KamonExt
2522
import fr.acinq.eclair.channel.ChannelPersisted
2623
import fr.acinq.eclair.db.Databases.FileBackup
2724
import fr.acinq.eclair.db.Monitoring.Metrics
2825

26+
import java.io.File
27+
import java.nio.file.{Files, StandardCopyOption}
2928
import scala.sys.process.Process
3029
import scala.util.{Failure, Success, Try}
3130

eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616

1717
package fr.acinq.eclair.db
1818

19-
import java.io.Closeable
20-
2119
import fr.acinq.bitcoin.Crypto.PublicKey
2220
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
2321
import fr.acinq.eclair.ShortChannelId
2422
import fr.acinq.eclair.router.Router.PublicChannel
2523
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
2624

25+
import java.io.Closeable
2726
import scala.collection.immutable.SortedMap
2827

2928
trait NetworkDb extends Closeable {

eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
package fr.acinq.eclair.db
1818

19-
import java.io.Closeable
20-
import java.util.UUID
21-
2219
import fr.acinq.bitcoin.ByteVector32
2320
import fr.acinq.bitcoin.Crypto.PublicKey
2421
import fr.acinq.eclair.payment._
2522
import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
2623
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}
2724

25+
import java.io.Closeable
26+
import java.util.UUID
27+
2828
trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable
2929

3030
trait IncomingPaymentsDb {

eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
package fr.acinq.eclair.db
1818

19-
import java.io.Closeable
20-
2119
import fr.acinq.bitcoin.Crypto.PublicKey
2220
import fr.acinq.eclair.wire.protocol.NodeAddress
2321

22+
import java.io.Closeable
23+
2424
trait PeersDb extends Closeable {
2525

2626
def addOrUpdatePeer(nodeId: PublicKey, address: NodeAddress): Unit

eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala

+24-18
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@ package fr.acinq.eclair.db.jdbc
1919
import fr.acinq.bitcoin.ByteVector32
2020
import fr.acinq.eclair.MilliSatoshi
2121
import org.sqlite.SQLiteConnection
22-
import scodec.Codec
22+
import scodec.Decoder
2323
import scodec.bits.{BitVector, ByteVector}
2424

2525
import java.sql.{Connection, ResultSet, Statement, Timestamp}
2626
import java.util.UUID
2727
import javax.sql.DataSource
28-
import scala.collection.immutable.Queue
2928

3029
trait JdbcUtils {
3130

31+
import ExtendedResultSet._
32+
3233
def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = {
3334
val connection = dataSource.getConnection()
3435
try {
@@ -72,15 +73,16 @@ trait JdbcUtils {
7273
def getVersion(statement: Statement, db_name: String): Option[Int] = {
7374
createVersionTable(statement)
7475
// if there was a previous version installed, this will return a different value from current version
75-
val rs = statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
76-
if (rs.next()) Some(rs.getInt("version")) else None
76+
statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
77+
.map(rs => rs.getInt("version"))
78+
.headOption
7779
}
7880

7981
/**
8082
* Updates the version for a particular logical database, it will overwrite the previous version.
8183
*
8284
* NB: we could define this method in [[fr.acinq.eclair.db.sqlite.SqliteUtils]] and [[fr.acinq.eclair.db.pg.PgUtils]]
83-
* but it would make testing more complicated because we need to use one or the other depending on the backend.
85+
* but it would make testing more complicated because we need to use one or the other depending on the backend.
8486
*/
8587
def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = {
8688
createVersionTable(statement)
@@ -96,20 +98,25 @@ trait JdbcUtils {
9698
}
9799
}
98100

99-
/**
100-
* This helper assumes that there is a "data" column available, decodable with the provided codec
101-
*
102-
* TODO: we should use an scala.Iterator instead
103-
*/
104-
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
105-
var q: Queue[T] = Queue()
106-
while (rs.next()) {
107-
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
101+
case class ExtendedResultSet(rs: ResultSet) extends Iterable[ResultSet] {
102+
103+
/**
104+
* Iterates over all rows of a result set.
105+
*
106+
* Careful: the iterator is lazy, it must be materialized before the [[ResultSet]] is closed, by converting the end
107+
* result in a collection or an option.
108+
*/
109+
override def iterator: Iterator[ResultSet] = {
110+
// @formatter:off
111+
new Iterator[ResultSet] {
112+
def hasNext: Boolean = rs.next()
113+
def next(): ResultSet = rs
114+
}
115+
// @formatter:on
108116
}
109-
q
110-
}
111117

112-
case class ExtendedResultSet(rs: ResultSet) {
118+
/** This helper assumes that there is a "data" column available, that can be decoded with the provided codec */
119+
def mapCodec[T](codec: Decoder[T]): Iterable[T] = rs.map(rs => codec.decode(BitVector(rs.getBytes("data"))).require.value)
113120

114121
def getByteVectorFromHex(columnLabel: String): ByteVector = {
115122
val s = rs.getString(columnLabel).stripPrefix("\\x")
@@ -166,7 +173,6 @@ trait JdbcUtils {
166173
val result = rs.getTimestamp(label)
167174
if (rs.wasNull()) None else Some(result)
168175
}
169-
170176
}
171177

172178
object ExtendedResultSet {

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala

+76-86
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import java.sql.{Statement, Timestamp}
3333
import java.time.Instant
3434
import java.util.UUID
3535
import javax.sql.DataSource
36-
import scala.collection.immutable.Queue
3736

3837
class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
3938

@@ -215,30 +214,28 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
215214
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
216215
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
217216
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
218-
val rs = statement.executeQuery()
219-
var sentByParentId = Map.empty[UUID, PaymentSent]
220-
while (rs.next()) {
221-
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
222-
val part = PaymentSent.PartialPayment(
223-
UUID.fromString(rs.getString("payment_id")),
224-
MilliSatoshi(rs.getLong("amount_msat")),
225-
MilliSatoshi(rs.getLong("fees_msat")),
226-
rs.getByteVector32FromHex("to_channel_id"),
227-
None, // we don't store the route in the audit DB
228-
rs.getTimestamp("timestamp").getTime)
229-
val sent = sentByParentId.get(parentId) match {
230-
case Some(s) => s.copy(parts = s.parts :+ part)
231-
case None => PaymentSent(
232-
parentId,
233-
rs.getByteVector32FromHex("payment_hash"),
234-
rs.getByteVector32FromHex("payment_preimage"),
235-
MilliSatoshi(rs.getLong("recipient_amount_msat")),
236-
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
237-
Seq(part))
238-
}
239-
sentByParentId = sentByParentId + (parentId -> sent)
240-
}
241-
sentByParentId.values.toSeq.sortBy(_.timestamp)
217+
statement.executeQuery()
218+
.foldLeft(Map.empty[UUID, PaymentSent]) { (sentByParentId, rs) =>
219+
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
220+
val part = PaymentSent.PartialPayment(
221+
UUID.fromString(rs.getString("payment_id")),
222+
MilliSatoshi(rs.getLong("amount_msat")),
223+
MilliSatoshi(rs.getLong("fees_msat")),
224+
rs.getByteVector32FromHex("to_channel_id"),
225+
None, // we don't store the route in the audit DB
226+
rs.getTimestamp("timestamp").getTime)
227+
val sent = sentByParentId.get(parentId) match {
228+
case Some(s) => s.copy(parts = s.parts :+ part)
229+
case None => PaymentSent(
230+
parentId,
231+
rs.getByteVector32FromHex("payment_hash"),
232+
rs.getByteVector32FromHex("payment_preimage"),
233+
MilliSatoshi(rs.getLong("recipient_amount_msat")),
234+
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
235+
Seq(part))
236+
}
237+
sentByParentId + (parentId -> sent)
238+
}.values.toSeq.sortBy(_.timestamp)
242239
}
243240
}
244241

@@ -247,98 +244,91 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
247244
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
248245
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
249246
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
250-
val rs = statement.executeQuery()
251-
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
252-
while (rs.next()) {
253-
val paymentHash = rs.getByteVector32FromHex("payment_hash")
254-
val part = PaymentReceived.PartialPayment(
255-
MilliSatoshi(rs.getLong("amount_msat")),
256-
rs.getByteVector32FromHex("from_channel_id"),
257-
rs.getTimestamp("timestamp").getTime)
258-
val received = receivedByHash.get(paymentHash) match {
259-
case Some(r) => r.copy(parts = r.parts :+ part)
260-
case None => PaymentReceived(paymentHash, Seq(part))
261-
}
262-
receivedByHash = receivedByHash + (paymentHash -> received)
263-
}
264-
receivedByHash.values.toSeq.sortBy(_.timestamp)
247+
statement.executeQuery()
248+
.foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) =>
249+
val paymentHash = rs.getByteVector32FromHex("payment_hash")
250+
val part = PaymentReceived.PartialPayment(
251+
MilliSatoshi(rs.getLong("amount_msat")),
252+
rs.getByteVector32FromHex("from_channel_id"),
253+
rs.getTimestamp("timestamp").getTime)
254+
val received = receivedByHash.get(paymentHash) match {
255+
case Some(r) => r.copy(parts = r.parts :+ part)
256+
case None => PaymentReceived(paymentHash, Seq(part))
257+
}
258+
receivedByHash + (paymentHash -> received)
259+
}.values.toSeq.sortBy(_.timestamp)
265260
}
266261
}
267262

268263
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
269264
inTransaction { pg =>
270-
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
271-
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
265+
val trampolineByHash = using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
272266
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
273267
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
274-
val rs = statement.executeQuery()
275-
while (rs.next()) {
276-
val paymentHash = rs.getByteVector32FromHex("payment_hash")
277-
val amount = MilliSatoshi(rs.getLong("amount_msat"))
278-
val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id"))
279-
trampolineByHash += (paymentHash -> (amount, nodeId))
280-
}
268+
statement.executeQuery()
269+
.foldLeft(Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]) { (trampolineByHash, rs) =>
270+
val paymentHash = rs.getByteVector32FromHex("payment_hash")
271+
val amount = MilliSatoshi(rs.getLong("amount_msat"))
272+
val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id"))
273+
trampolineByHash + (paymentHash -> (amount, nodeId))
274+
}
281275
}
282-
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
276+
val relayedByHash = using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
283277
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
284278
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
285-
val rs = statement.executeQuery()
286-
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
287-
while (rs.next()) {
288-
val paymentHash = rs.getByteVector32FromHex("payment_hash")
289-
val part = RelayedPart(
290-
rs.getByteVector32FromHex("channel_id"),
291-
MilliSatoshi(rs.getLong("amount_msat")),
292-
rs.getString("direction"),
293-
rs.getString("relay_type"),
294-
rs.getTimestamp("timestamp").getTime)
295-
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
296-
}
297-
relayedByHash.flatMap {
298-
case (paymentHash, parts) =>
299-
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
300-
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
301-
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
302-
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
303-
parts.headOption match {
304-
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
305-
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
306-
}
307-
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) =>
308-
val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey))
309-
TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil
310-
case _ => Nil
311-
}
312-
}.toSeq.sortBy(_.timestamp)
279+
statement.executeQuery()
280+
.foldLeft(Map.empty[ByteVector32, Seq[RelayedPart]]) { (relayedByHash, rs) =>
281+
val paymentHash = rs.getByteVector32FromHex("payment_hash")
282+
val part = RelayedPart(
283+
rs.getByteVector32FromHex("channel_id"),
284+
MilliSatoshi(rs.getLong("amount_msat")),
285+
rs.getString("direction"),
286+
rs.getString("relay_type"),
287+
rs.getTimestamp("timestamp").getTime)
288+
relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
289+
}
313290
}
291+
relayedByHash.flatMap {
292+
case (paymentHash, parts) =>
293+
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
294+
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
295+
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
296+
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
297+
parts.headOption match {
298+
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
299+
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
300+
}
301+
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) =>
302+
val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey))
303+
TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil
304+
case _ => Nil
305+
}
306+
}.toSeq.sortBy(_.timestamp)
314307
}
315308

316309
override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
317310
inTransaction { pg =>
318311
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
319312
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
320313
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
321-
val rs = statement.executeQuery()
322-
var q: Queue[NetworkFee] = Queue()
323-
while (rs.next()) {
324-
q = q :+ NetworkFee(
314+
statement.executeQuery().map { rs =>
315+
NetworkFee(
325316
remoteNodeId = PublicKey(rs.getByteVectorFromHex("node_id")),
326317
channelId = rs.getByteVector32FromHex("channel_id"),
327318
txId = rs.getByteVector32FromHex("tx_id"),
328319
fee = Satoshi(rs.getLong("fee_sat")),
329320
txType = rs.getString("tx_type"),
330321
timestamp = rs.getTimestamp("timestamp").getTime)
331-
}
332-
q
322+
}.toSeq
333323
}
334324
}
335325

336326
override def stats(from: Long, to: Long): Seq[Stats] = {
337-
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) =>
327+
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) =>
338328
feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee))
339329
}
340330
case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String)
341-
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) =>
331+
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) =>
342332
// NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones.
343333
val current = e match {
344334
case c: ChannelPaymentRelayed => Map(

0 commit comments

Comments
 (0)