Skip to content

Commit e14c40d

Browse files
authored
Use proper data type for timestamps in Postgres (#1778)
Did some refactoring in tests and introduced a new `migrationCheck` helper method. Note that the change of data type in sqlite for the `commitment_number` field (from `BLOB` to `INTEGER`) is not a migration. If the table has been created before, it will stay like it was. It doesn't matter due to how sqlite stores data, and we make sure in tests that there is no regression.
1 parent 4a1dfd2 commit e14c40d

File tree

7 files changed

+534
-395
lines changed

7 files changed

+534
-395
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala

+9-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.sqlite.SQLiteConnection
2222
import scodec.Codec
2323
import scodec.bits.{BitVector, ByteVector}
2424

25-
import java.sql.{Connection, ResultSet, Statement}
25+
import java.sql.{Connection, ResultSet, Statement, Timestamp}
2626
import java.util.UUID
2727
import javax.sql.DataSource
2828
import scala.collection.immutable.Queue
@@ -123,18 +123,16 @@ trait JdbcUtils {
123123

124124
def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
125125
val s = rs.getString(columnLabel)
126-
if (rs.wasNull()) None else {
127-
Some(ByteVector32(ByteVector.fromValidHex(s)))
128-
}
126+
if (rs.wasNull()) None else Some(ByteVector32(ByteVector.fromValidHex(s)))
129127
}
130128

131129
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
132130

133131
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
134132

135-
def getByteVectorNullable(columnLabel: String): ByteVector = {
133+
def getByteVectorNullable(columnLabel: String): Option[ByteVector] = {
136134
val result = rs.getBytes(columnLabel)
137-
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
135+
if (rs.wasNull()) None else Some(ByteVector(result))
138136
}
139137

140138
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
@@ -164,6 +162,11 @@ trait JdbcUtils {
164162
if (rs.wasNull()) None else Some(MilliSatoshi(result))
165163
}
166164

165+
def getTimestampNullable(label: String): Option[Timestamp] = {
166+
val result = rs.getTimestamp(label)
167+
if (rs.wasNull()) None else Some(result)
168+
}
169+
167170
}
168171

169172
object ExtendedResultSet {

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala

+50-35
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey
2929
import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong}
3030
import grizzled.slf4j.Logging
3131

32-
import java.sql.Statement
32+
import java.sql.{Statement, Timestamp}
33+
import java.time.Instant
3334
import java.util.UUID
3435
import javax.sql.DataSource
3536
import scala.collection.immutable.Queue
@@ -40,7 +41,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
4041
import ExtendedResultSet._
4142

4243
val DB_NAME = "audit"
43-
val CURRENT_VERSION = 5
44+
val CURRENT_VERSION = 6
4445

4546
case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)
4647

@@ -52,15 +53,25 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
5253
statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)")
5354
}
5455

56+
def migration56(statement: Statement): Unit = {
57+
statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
58+
statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
59+
statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
60+
statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
61+
statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
62+
statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
63+
statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
64+
}
65+
5566
getVersion(statement, DB_NAME) match {
5667
case None =>
57-
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
58-
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
59-
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
60-
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
61-
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
62-
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
63-
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
68+
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
69+
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
70+
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
71+
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
72+
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
73+
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
74+
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
6475

6576
statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)")
6677
statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)")
@@ -74,6 +85,10 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
7485
case Some(v@4) =>
7586
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
7687
migration45(statement)
88+
migration56(statement)
89+
case Some(v@5) =>
90+
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
91+
migration56(statement)
7792
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
7893
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
7994
}
@@ -90,7 +105,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
90105
statement.setBoolean(4, e.isFunder)
91106
statement.setBoolean(5, e.isPrivate)
92107
statement.setString(6, e.event.label)
93-
statement.setLong(7, System.currentTimeMillis)
108+
statement.setTimestamp(7, Timestamp.from(Instant.now()))
94109
statement.executeUpdate()
95110
}
96111
}
@@ -109,7 +124,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
109124
statement.setString(7, e.paymentPreimage.toHex)
110125
statement.setString(8, e.recipientNodeId.value.toHex)
111126
statement.setString(9, p.toChannelId.toHex)
112-
statement.setLong(10, p.timestamp)
127+
statement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
113128
statement.addBatch()
114129
})
115130
statement.executeBatch()
@@ -124,7 +139,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
124139
statement.setLong(1, p.amount.toLong)
125140
statement.setString(2, e.paymentHash.toHex)
126141
statement.setString(3, p.fromChannelId.toHex)
127-
statement.setLong(4, p.timestamp)
142+
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
128143
statement.addBatch()
129144
})
130145
statement.executeBatch()
@@ -143,7 +158,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
143158
statement.setString(1, e.paymentHash.toHex)
144159
statement.setLong(2, nextTrampolineAmount.toLong)
145160
statement.setString(3, nextTrampolineNodeId.value.toHex)
146-
statement.setLong(4, e.timestamp)
161+
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
147162
statement.executeUpdate()
148163
}
149164
// trampoline relayed payments do MPP aggregation and may have M inputs and N outputs
@@ -156,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
156171
statement.setString(3, p.channelId.toHex)
157172
statement.setString(4, p.direction)
158173
statement.setString(5, p.relayType)
159-
statement.setLong(6, e.timestamp)
174+
statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
160175
statement.executeUpdate()
161176
}
162177
}
@@ -171,7 +186,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
171186
statement.setString(3, e.tx.txid.toHex)
172187
statement.setLong(4, e.fee.toLong)
173188
statement.setString(5, e.txType)
174-
statement.setLong(6, System.currentTimeMillis)
189+
statement.setTimestamp(6, Timestamp.from(Instant.now()))
175190
statement.executeUpdate()
176191
}
177192
}
@@ -189,17 +204,17 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
189204
statement.setString(3, errorName)
190205
statement.setString(4, errorMessage)
191206
statement.setBoolean(5, e.isFatal)
192-
statement.setLong(6, System.currentTimeMillis)
207+
statement.setTimestamp(6, Timestamp.from(Instant.now()))
193208
statement.executeUpdate()
194209
}
195210
}
196211
}
197212

198213
override def listSent(from: Long, to: Long): Seq[PaymentSent] =
199214
inTransaction { pg =>
200-
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement =>
201-
statement.setLong(1, from)
202-
statement.setLong(2, to)
215+
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
216+
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
217+
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
203218
val rs = statement.executeQuery()
204219
var sentByParentId = Map.empty[UUID, PaymentSent]
205220
while (rs.next()) {
@@ -210,7 +225,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
210225
MilliSatoshi(rs.getLong("fees_msat")),
211226
rs.getByteVector32FromHex("to_channel_id"),
212227
None, // we don't store the route in the audit DB
213-
rs.getLong("timestamp"))
228+
rs.getTimestamp("timestamp").getTime)
214229
val sent = sentByParentId.get(parentId) match {
215230
case Some(s) => s.copy(parts = s.parts :+ part)
216231
case None => PaymentSent(
@@ -229,17 +244,17 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
229244

230245
override def listReceived(from: Long, to: Long): Seq[PaymentReceived] =
231246
inTransaction { pg =>
232-
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
233-
statement.setLong(1, from)
234-
statement.setLong(2, to)
247+
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
248+
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
249+
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
235250
val rs = statement.executeQuery()
236251
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
237252
while (rs.next()) {
238253
val paymentHash = rs.getByteVector32FromHex("payment_hash")
239254
val part = PaymentReceived.PartialPayment(
240255
MilliSatoshi(rs.getLong("amount_msat")),
241256
rs.getByteVector32FromHex("from_channel_id"),
242-
rs.getLong("timestamp"))
257+
rs.getTimestamp("timestamp").getTime)
243258
val received = receivedByHash.get(paymentHash) match {
244259
case Some(r) => r.copy(parts = r.parts :+ part)
245260
case None => PaymentReceived(paymentHash, Seq(part))
@@ -253,9 +268,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
253268
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
254269
inTransaction { pg =>
255270
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
256-
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement =>
257-
statement.setLong(1, from)
258-
statement.setLong(2, to)
271+
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
272+
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
273+
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
259274
val rs = statement.executeQuery()
260275
while (rs.next()) {
261276
val paymentHash = rs.getByteVector32FromHex("payment_hash")
@@ -264,9 +279,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
264279
trampolineByHash += (paymentHash -> (amount, nodeId))
265280
}
266281
}
267-
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
268-
statement.setLong(1, from)
269-
statement.setLong(2, to)
282+
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
283+
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
284+
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
270285
val rs = statement.executeQuery()
271286
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
272287
while (rs.next()) {
@@ -276,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
276291
MilliSatoshi(rs.getLong("amount_msat")),
277292
rs.getString("direction"),
278293
rs.getString("relay_type"),
279-
rs.getLong("timestamp"))
294+
rs.getTimestamp("timestamp").getTime)
280295
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
281296
}
282297
relayedByHash.flatMap {
@@ -300,9 +315,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
300315

301316
override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
302317
inTransaction { pg =>
303-
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
304-
statement.setLong(1, from)
305-
statement.setLong(2, to)
318+
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
319+
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
320+
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
306321
val rs = statement.executeQuery()
307322
var q: Queue[NetworkFee] = Queue()
308323
while (rs.next()) {
@@ -312,7 +327,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
312327
txId = rs.getByteVector32FromHex("tx_id"),
313328
fee = Satoshi(rs.getLong("fee_sat")),
314329
txType = rs.getString("tx_type"),
315-
timestamp = rs.getLong("timestamp"))
330+
timestamp = rs.getTimestamp("timestamp").getTime)
316331
}
317332
q
318333
}

0 commit comments

Comments
 (0)