diff --git a/.gitignore b/.gitignore index 936d62a073a8..a9a0804fe6f6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.gcno *.dSYM *.rej +*.po *.pyc .cppcheck-suppress TAGS diff --git a/Makefile b/Makefile index fdd00ea8db30..9928e448555b 100644 --- a/Makefile +++ b/Makefile @@ -201,6 +201,11 @@ else LDLIBS = -L/usr/local/lib -lm -lgmp -lsqlite3 -lz $(COVFLAGS) endif +# If we have the postgres client library we need to link against it as well +ifeq ($(HAVE_POSTGRES),1) +LDLIBS += -lpq +endif + default: all-programs all-test-programs ccan/config.h: config.vars configure ccan/tools/configurator/configurator.c diff --git a/configure b/configure index 1d22c4704b78..85066278ca62 100755 --- a/configure +++ b/configure @@ -273,6 +273,20 @@ int main(void) return 0; } /*END*/ +var=HAVE_POSTGRES +desc=postgres +style=DEFINES_EVERYTHING|EXECUTE|MAY_NOT_COMPILE +link=-lpq +code= +#include +#include + +int main(void) +{ + printf("libpq version %d\n", PQlibVersion()); + return 0; +} +/*END*/ var=HAVE_GCC desc=compiler is GCC style=OUTSIDE_MAIN diff --git a/devtools/sql-rewrite.py b/devtools/sql-rewrite.py index 98d693bc2802..0e62b99d2163 100755 --- a/devtools/sql-rewrite.py +++ b/devtools/sql-rewrite.py @@ -2,16 +2,73 @@ from mako.template import Template +import re import sys -class Sqlite3Rewriter(object): - def rewrite(self, query): +DEBUG = False + + +def eprint(*args, **kwargs): + if not DEBUG: + return + print(*args, **kwargs, file=sys.stderr) + + +class Rewriter(object): + + def rewrite_types(self, query, mapping): + for old, new in mapping.items(): + query = re.sub(old, new, query) + return query + + def rewrite_single(self, query): + return query + + def rewrite(self, queries): + for i, q in enumerate(queries): + org = q['query'] + queries[i]['query'] = self.rewrite_single(org) + eprint("Rewritten statement\n\tfrom {}\n\t to {}".format(org, q['query'])) + return queries + + +class Sqlite3Rewriter(Rewriter): + def rewrite_single(self, query): + typemapping = { + r'BIGINT': 'INTEGER', + r'BIGINTEGER': 'INTEGER', + r'BIGSERIAL': 'INTEGER', + r'CURRENT_TIMESTAMP\(\)': "strftime('%s', 'now')", + r'INSERT INTO[ \t]+(.*)[ \t]+ON CONFLICT.*DO NOTHING;': 'INSERT OR IGNORE INTO \\1;', + } + return self.rewrite_types(query, typemapping) + + +class PostgresRewriter(Rewriter): + def rewrite_single(self, q): + # Let's start by replacing any eventual '?' placeholders + q2 = "" + count = 1 + for c in q: + if c == '?': + c = "${}".format(count) + count += 1 + q2 += c + query = q2 + + typemapping = { + r'BLOB': 'BYTEA', + r'CURRENT_TIMESTAMP\(\)': "EXTRACT(epoch FROM now())", + } + + query = self.rewrite_types(query, typemapping) return query rewriters = { "sqlite3": Sqlite3Rewriter(), + "postgres": PostgresRewriter(), } template = Template("""#ifndef LIGHTNINGD_WALLET_GEN_DB_${f.upper()} @@ -62,7 +119,6 @@ def chunk(pofile): queries = [] for c in chunk(pofile): - name = c[0][3:] # Skip other comments i = 1 @@ -73,7 +129,7 @@ def chunk(pofile): query = c[i][7:][:-1] queries.append({ - 'name': name, + 'name': query, 'query': query, 'placeholders': query.count('?'), 'readonly': "true" if query.upper().startswith("SELECT") else "false", diff --git a/doc/lightningd-config.5 b/doc/lightningd-config.5 index 9cdf50e09fb1..161c8cdba4b5 100644 --- a/doc/lightningd-config.5 +++ b/doc/lightningd-config.5 @@ -149,6 +149,12 @@ is a relative path, it is relative to the starting directory, not readable (we allow missing files in the default case)\. Using this inside a configuration file is meaningless\. + + \fBwallet\fR=\fIDSN\fR +Identify the location of the wallet\. This is a fully qualified data source +name, including a scheme such as \fBsqlite3\fR or \fBpostgres\fR followed by the +connection parameters\. + .SH Lightning node customization options \fBalias\fR=\fIRRGGBB\fR diff --git a/doc/lightningd-config.5.md b/doc/lightningd-config.5.md index 1b6a08ea6119..7b341fbe361a 100644 --- a/doc/lightningd-config.5.md +++ b/doc/lightningd-config.5.md @@ -128,6 +128,11 @@ is a relative path, it is relative to the starting directory, not readable (we allow missing files in the default case). Using this inside a configuration file is meaningless. + **wallet**=*DSN* +Identify the location of the wallet. This is a fully qualified data source +name, including a scheme such as `sqlite3` or `postgres` followed by the +connection parameters. + ### Lightning node customization options **alias**=*RRGGBB* diff --git a/lightningd/lightningd.h b/lightningd/lightningd.h index 04e08ceae016..06452c58f50b 100644 --- a/lightningd/lightningd.h +++ b/lightningd/lightningd.h @@ -228,6 +228,8 @@ struct lightningd { const char *original_directory; struct plugins *plugins; + + char *wallet_dsn; }; /* Turning this on allows a tal allocation to return NULL, rather than aborting. diff --git a/lightningd/options.c b/lightningd/options.c index 6136043882de..8b35ce8abfad 100644 --- a/lightningd/options.c +++ b/lightningd/options.c @@ -808,6 +808,11 @@ static void handle_minimal_config_opts(struct lightningd *ld, opt_ignore_talstr, opt_show_charp, &ld->config_dir, "Set working directory. All other files are relative to this"); + + ld->wallet_dsn = tal_fmt(ld, "sqlite3://%s/lightningd.sqlite3", ld->config_dir); + opt_register_early_arg("--wallet", opt_set_talstr, NULL, + &ld->wallet_dsn, + "Location of the wallet database."); } static void register_opts(struct lightningd *ld) diff --git a/requirements.txt b/requirements.txt index 5cca28d2b98e..5e137d42e09e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ -sqlparse==0.3.0 mako==1.0.14 mrkd==0.1.5 diff --git a/tests/db.py b/tests/db.py new file mode 100644 index 000000000000..bb153221d7df --- /dev/null +++ b/tests/db.py @@ -0,0 +1,197 @@ +from ephemeral_port_reserve import reserve +from glob import glob + +import logging +import os +import psycopg2 +import random +import re +import shutil +import signal +import sqlite3 +import string +import subprocess +import time + + +class Sqlite3Db(object): + def __init__(self, path): + self.path = path + + def get_dsn(self): + """SQLite3 doesn't provide a DSN, resulting in no CLI-option. + """ + return None + + def query(self, query): + orig = os.path.join(self.path) + copy = self.path + ".copy" + shutil.copyfile(orig, copy) + db = sqlite3.connect(copy) + + db.row_factory = sqlite3.Row + c = db.cursor() + c.execute(query) + rows = c.fetchall() + + result = [] + for row in rows: + result.append(dict(zip(row.keys(), row))) + + db.commit() + c.close() + db.close() + return result + + def execute(self, query): + db = sqlite3.connect(self.path) + c = db.cursor() + c.execute(query) + db.commit() + c.close() + db.close() + + +class PostgresDb(object): + def __init__(self, dbname, port): + self.dbname = dbname + self.port = port + + self.conn = psycopg2.connect("dbname={dbname} user=postgres host=localhost port={port}".format( + dbname=dbname, port=port + )) + cur = self.conn.cursor() + cur.execute('SELECT 1') + cur.close() + + def get_dsn(self): + return "postgres://postgres:password@localhost:{port}/{dbname}".format( + port=self.port, dbname=self.dbname + ) + + def query(self, query): + cur = self.conn.cursor() + cur.execute(query) + + # Collect the results into a list of dicts. + res = [] + for r in cur: + t = {} + # Zip the column definition with the value to get its name. + for c, v in zip(cur.description, r): + t[c.name] = v + res.append(t) + cur.close() + return res + + def execute(self, query): + with self.conn, self.conn.cursor() as cur: + cur.execute(query) + + +class SqliteDbProvider(object): + def __init__(self, directory): + self.directory = directory + + def start(self): + pass + + def get_db(self, node_directory, testname, node_id): + path = os.path.join( + node_directory, + 'lightningd.sqlite3' + ) + return Sqlite3Db(path) + + def stop(self): + pass + + +class PostgresDbProvider(object): + def __init__(self, directory): + self.directory = directory + self.port = None + self.proc = None + print("Starting PostgresDbProvider") + + def locate_path(self): + prefix = '/usr/lib/postgresql/*' + matches = glob(prefix) + + candidates = {} + for m in matches: + g = re.search(r'([0-9]+[\.0-9]*)', m) + if not g: + continue + candidates[float(g.group(1))] = m + + if len(candidates) == 0: + raise ValueError("Could not find `postgres` and `initdb` binaries in {}. Is postgresql installed?".format(prefix)) + + # Now iterate in reverse order through matches + for k, v in sorted(candidates.items())[::-1]: + initdb = os.path.join(v, 'bin', 'initdb') + postgres = os.path.join(v, 'bin', 'postgres') + if os.path.isfile(initdb) and os.path.isfile(postgres): + logging.info("Found `postgres` and `initdb` in {}".format(os.path.join(v, 'bin'))) + return initdb, postgres + + raise ValueError("Could not find `postgres` and `initdb` in any of the possible paths: {}".format(candidates.values())) + + def start(self): + passfile = os.path.join(self.directory, "pgpass.txt") + self.pgdir = os.path.join(self.directory, 'pgsql') + # Need to write a tiny file containing the password so `initdb` can pick it up + with open(passfile, 'w') as f: + f.write('cltest\n') + + initdb, postgres = self.locate_path() + subprocess.check_call([ + initdb, + '--pwfile={}'.format(passfile), + '--pgdata={}'.format(self.pgdir), + '--auth=trust', + '--username=postgres', + ]) + self.port = reserve() + self.proc = subprocess.Popen([ + postgres, + '-k', '/tmp/', # So we don't use /var/lib/... + '-D', self.pgdir, + '-p', str(self.port), + '-F', + '-i', + ]) + # Hacky but seems to work ok (might want to make the postgres proc a TailableProc as well if too flaky). + time.sleep(1) + self.conn = psycopg2.connect("dbname=template1 user=postgres host=localhost port={}".format(self.port)) + + # Required for CREATE DATABASE to work + self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + + def get_db(self, node_directory, testname, node_id): + # Random suffix to avoid collisions on repeated tests + nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) + dbname = "{}_{}_{}".format(testname, node_id, nonce) + + cur = self.conn.cursor() + cur.execute("CREATE DATABASE {};".format(dbname)) + cur.close() + db = PostgresDb(dbname, self.port) + return db + + def stop(self): + # Send fast shutdown signal see [1] for details: + # + # SIGINT + # + # This is the Fast Shutdown mode. The server disallows new connections + # and sends all existing server processes SIGTERM, which will cause + # them to abort their current transactions and exit promptly. It then + # waits for all server processes to exit and finally shuts down. If + # the server is in online backup mode, backup mode will be terminated, + # rendering the backup useless. + # + # [1] https://www.postgresql.org/docs/9.1/server-shutdown.html + self.proc.send_signal(signal.SIGINT) + self.proc.wait() diff --git a/tests/fixtures.py b/tests/fixtures.py index f6c4456c6878..0b38b2de2570 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,4 +1,5 @@ from concurrent import futures +from db import SqliteDbProvider, PostgresDbProvider from utils import NodeFactory, BitcoinD import logging @@ -149,12 +150,13 @@ def teardown_checks(request): @pytest.fixture -def node_factory(request, directory, test_name, bitcoind, executor, teardown_checks): +def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks): nf = NodeFactory( test_name, bitcoind, executor, directory=directory, + db_provider=db_provider, ) yield nf @@ -275,6 +277,21 @@ def checkMemleak(node): return 0 +# Mapping from TEST_DB_PROVIDER env variable to class to be used +providers = { + 'sqlite3': SqliteDbProvider, + 'postgres': PostgresDbProvider, +} + + +@pytest.fixture(scope="session") +def db_provider(test_base_dir): + provider = providers[os.getenv('TEST_DB_PROVIDER', 'sqlite3')](test_base_dir) + provider.start() + yield provider + provider.stop() + + @pytest.fixture def executor(teardown_checks): ex = futures.ThreadPoolExecutor(max_workers=20) diff --git a/tests/requirements.txt b/tests/requirements.txt index 4208596b7127..bfa85f96305e 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,3 +10,4 @@ pytest-xdist==1.29.0 python-bitcoinlib==0.10.1 tqdm==4.32.2 pytest-timeout==1.3.3 +psycopg2==2.8.3 diff --git a/tests/test_closing.py b/tests/test_closing.py index 9b1e2ae7b1f6..13dbbbd19c1a 100644 --- a/tests/test_closing.py +++ b/tests/test_closing.py @@ -1550,7 +1550,7 @@ def test_option_upfront_shutdown_script(node_factory, bitcoind): wait_for(lambda: [c['state'] for c in only_one(l2.rpc.listpeers()['peers'])['channels']] == ['ONCHAIN', 'ONCHAIN']) # Figure out what address it will try to use. - keyidx = int(l1.db_query("SELECT val FROM vars WHERE name='bip32_max_index';")[0]['val']) + keyidx = int(l1.db_query("SELECT intval FROM vars WHERE name='bip32_max_index';")[0]['intval']) # Expect 1 for change address, 1 for the channel final address, # which are discarded as the 'scratch' tx that the fundchannel diff --git a/tests/test_connection.py b/tests/test_connection.py index a86cd3d6a0bd..c2f1e98800ed 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1583,8 +1583,9 @@ def test_no_fee_estimate(node_factory, bitcoind, executor): l1.rpc.connect(l2.info['id'], 'localhost', l2.port) l1.rpc.fundchannel(l2.info['id'], 10**6, 'slow') - # Can withdraw (use urgent feerate). - l1.rpc.withdraw(l2.rpc.newaddr()['bech32'], 'all', 'urgent') + # Can withdraw (use urgent feerate). `minconf` may be needed depending on + # the previous `fundchannel` selecting all confirmed outputs. + l1.rpc.withdraw(l2.rpc.newaddr()['bech32'], 'all', 'urgent', minconf=0) @unittest.skipIf(not DEVELOPER, "needs --dev-disconnect") @@ -1627,6 +1628,7 @@ def test_funder_simple_reconnect(node_factory, bitcoind): l1.pay(l2, 200000000) +@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "sqlite3-specific DB rollback") @unittest.skipIf(not DEVELOPER, "needs LIGHTNINGD_DEV_LOG_IO") def test_dataloss_protection(node_factory, bitcoind): l1 = node_factory.get_node(may_reconnect=True, log_all_io=True, diff --git a/tests/test_db.py b/tests/test_db.py index 9358519c0eb6..8bf3495088f8 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,6 +1,7 @@ from fixtures import * # noqa: F401,F403 from utils import wait_for, sync_blockheight, COMPAT +import os import unittest @@ -116,6 +117,7 @@ def test_max_channel_id(node_factory, bitcoind): @unittest.skipIf(not COMPAT, "needs COMPAT to convert obsolete db") +@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "This test is based on a sqlite3 snapshot") def test_scid_upgrade(node_factory): # Created through the power of sed "s/X'\([0-9]*\)78\([0-9]*\)78\([0-9]*\)'/X'\13A\23A\3'/" diff --git a/tests/test_misc.py b/tests/test_misc.py index db741caa964c..452ffebaefda 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -61,6 +61,7 @@ def test_names(node_factory): .format(key, alias, color)) +@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "This migration is based on a sqlite3 snapshot") def test_db_upgrade(node_factory): l1 = node_factory.get_node() l1.stop() @@ -969,12 +970,11 @@ def test_reserve_enforcement(node_factory, executor): l2.stop() # They should both aim for 1%. - reserves = l2.db_query('SELECT channel_reserve_satoshis FROM channel_configs') + reserves = l2.db.query('SELECT channel_reserve_satoshis FROM channel_configs') assert reserves == [{'channel_reserve_satoshis': 10**6 // 100}] * 2 # Edit db to reduce reserve to 0 so it will try to violate it. - l2.db_query('UPDATE channel_configs SET channel_reserve_satoshis=0', - use_copy=False) + l2.db.execute('UPDATE channel_configs SET channel_reserve_satoshis=0') l2.start() wait_for(lambda: only_one(l2.rpc.listpeers(l1.info['id'])['peers'])['connected']) diff --git a/tests/test_pay.py b/tests/test_pay.py index 8319cea03b8c..1810c42ee516 100644 --- a/tests/test_pay.py +++ b/tests/test_pay.py @@ -1824,9 +1824,13 @@ def test_setchannelfee_usage(node_factory, bitcoind): l1.rpc.connect(l3.info['id'], 'localhost', l3.port) l1.fund_channel(l2, 1000000) + def channel_get_fees(scid): + return l1.db.query( + 'SELECT feerate_base, feerate_ppm FROM channels ' + 'WHERE short_channel_id=\'{}\';'.format(scid)) + # get short channel id scid = l1.get_channel_scid(l2) - scid_hex = scid.encode('utf-8').hex() # feerates should be init with global config db_fees = l1.db_query('SELECT feerate_base, feerate_ppm FROM channels;') @@ -1845,9 +1849,7 @@ def test_setchannelfee_usage(node_factory, bitcoind): assert(result['channels'][0]['short_channel_id'] == scid) # check if custom values made it into the database - db_fees = l1.db_query( - 'SELECT feerate_base, feerate_ppm FROM channels ' - 'WHERE hex(short_channel_id)="' + scid_hex + '";') + db_fees = channel_get_fees(scid) assert(db_fees[0]['feerate_base'] == 1337) assert(db_fees[0]['feerate_ppm'] == 137) @@ -1878,9 +1880,7 @@ def test_setchannelfee_usage(node_factory, bitcoind): result = l1.rpc.setchannelfee(scid, 0, 0) assert(result['base'] == 0) assert(result['ppm'] == 0) - db_fees = l1.db_query( - 'SELECT feerate_base, feerate_ppm FROM channels ' - 'WHERE hex(short_channel_id)="' + scid_hex + '";') + db_fees = channel_get_fees(scid) assert(db_fees[0]['feerate_base'] == 0) assert(db_fees[0]['feerate_ppm'] == 0) @@ -1889,9 +1889,7 @@ def test_setchannelfee_usage(node_factory, bitcoind): assert(result['base'] == DEF_BASE) assert(result['ppm'] == DEF_PPM) # check default values in DB - db_fees = l1.db_query( - 'SELECT feerate_base, feerate_ppm FROM channels ' - 'WHERE hex(short_channel_id)="' + scid_hex + '";') + db_fees = channel_get_fees(scid) assert(db_fees[0]['feerate_base'] == DEF_BASE) assert(db_fees[0]['feerate_ppm'] == DEF_PPM) @@ -1902,9 +1900,7 @@ def test_setchannelfee_usage(node_factory, bitcoind): assert(len(result['channels']) == 1) assert(result['channels'][0]['peer_id'] == l2.info['id']) assert(result['channels'][0]['short_channel_id'] == scid) - db_fees = l1.db_query( - 'SELECT feerate_base, feerate_ppm FROM channels ' - 'WHERE hex(short_channel_id)="' + scid_hex + '";') + db_fees = channel_get_fees(scid) assert(db_fees[0]['feerate_base'] == 42) assert(db_fees[0]['feerate_ppm'] == 43) @@ -1917,9 +1913,7 @@ def test_setchannelfee_usage(node_factory, bitcoind): # check if 'base' unit can be modified to satoshi result = l1.rpc.setchannelfee(scid, '1sat') assert(result['base'] == 1000) - db_fees = l1.db_query( - 'SELECT feerate_base, feerate_ppm FROM channels ' - 'WHERE hex(short_channel_id)="' + scid_hex + '";') + db_fees = channel_get_fees(scid) assert(db_fees[0]['feerate_base'] == 1000) # check if 'ppm' values greater than u32_max fail diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 1e253ef911b9..6f2b85b18de7 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -290,6 +290,7 @@ def test_async_rpcmethod(node_factory, executor): assert [r.result() for r in results] == [42] * len(results) +@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "Only sqlite3 implements the db_write_hook currently") def test_db_hook(node_factory, executor): """This tests the db hook.""" dbfile = os.path.join(node_factory.directory, "dblog.sqlite3") diff --git a/tests/utils.py b/tests/utils.py index 92be43f5e645..b9f946add264 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -462,7 +462,9 @@ def wait(self, timeout=10): class LightningNode(object): - def __init__(self, daemon, rpc, btc, executor, may_fail=False, may_reconnect=False, allow_broken_log=False, allow_bad_gossip=False): + def __init__(self, daemon, rpc, btc, executor, may_fail=False, + may_reconnect=False, allow_broken_log=False, + allow_bad_gossip=False, db=None): self.rpc = rpc self.daemon = daemon self.bitcoin = btc @@ -471,6 +473,7 @@ def __init__(self, daemon, rpc, btc, executor, may_fail=False, may_reconnect=Fal self.may_reconnect = may_reconnect self.allow_broken_log = allow_broken_log self.allow_bad_gossip = allow_bad_gossip + self.db = db def connect(self, remote_node): self.rpc.connect(remote_node.info['id'], '127.0.0.1', remote_node.daemon.port) @@ -510,28 +513,8 @@ def fundwallet(self, sats, addrtype="p2sh-segwit"): def getactivechannels(self): return [c for c in self.rpc.listchannels()['channels'] if c['active']] - def db_query(self, query, use_copy=True): - orig = os.path.join(self.daemon.lightning_dir, "lightningd.sqlite3") - if use_copy: - copy = os.path.join(self.daemon.lightning_dir, "lightningd-copy.sqlite3") - shutil.copyfile(orig, copy) - db = sqlite3.connect(copy) - else: - db = sqlite3.connect(orig) - - db.row_factory = sqlite3.Row - c = db.cursor() - c.execute(query) - rows = c.fetchall() - - result = [] - for row in rows: - result.append(dict(zip(row.keys(), row))) - - db.commit() - c.close() - db.close() - return result + def db_query(self, query): + return self.db.query(query) # Assumes node is stopped! def db_manip(self, query): @@ -771,7 +754,7 @@ def wait_for_onchaind_broadcast(self, name, resolve=None): class NodeFactory(object): """A factory to setup and start `lightningd` daemons. """ - def __init__(self, testname, bitcoind, executor, directory): + def __init__(self, testname, bitcoind, executor, directory, db_provider): self.testname = testname self.next_id = 1 self.nodes = [] @@ -779,6 +762,7 @@ def __init__(self, testname, bitcoind, executor, directory): self.bitcoind = bitcoind self.directory = directory self.lock = threading.Lock() + self.db_provider = db_provider def split_options(self, opts): """Split node options from cli options @@ -880,11 +864,17 @@ def get_node(self, disconnect=None, options=None, may_fail=False, if options is not None: daemon.opts.update(options) + # Get the DB backend DSN we should be using for this test and this node. + db = self.db_provider.get_db(lightning_dir, self.testname, node_id) + dsn = db.get_dsn() + if dsn is not None: + daemon.opts['wallet'] = dsn + rpc = LightningRpc(socket_path, self.executor) node = LightningNode(daemon, rpc, self.bitcoind, self.executor, may_fail=may_fail, may_reconnect=may_reconnect, allow_broken_log=allow_broken_log, - allow_bad_gossip=allow_bad_gossip) + allow_bad_gossip=allow_bad_gossip, db=db) # Regtest estimatefee are unusable, so override. node.set_feerates(feerates, False) diff --git a/wallet/Makefile b/wallet/Makefile index 53993c5b4058..2529a38889b4 100644 --- a/wallet/Makefile +++ b/wallet/Makefile @@ -12,6 +12,7 @@ WALLET_LIB_SRC := \ wallet/walletrpc.c WALLET_DB_DRIVERS := \ + wallet/db_postgres.c \ wallet/db_sqlite3.c WALLET_LIB_OBJS := $(WALLET_LIB_SRC:.c=.o) $(WALLET_DB_DRIVERS:.c=.o) @@ -29,7 +30,9 @@ check-source-bolt: $(WALLET_LIB_SRC:%=bolt-check/%) $(WALLET_LIB_HEADERS:%=bolt- clean: wallet-clean +# Each database driver depends on its rewritten statements. wallet/db_sqlite3.c: wallet/gen_db_sqlite3.c +wallet/db_postgres.c: wallet/gen_db_postgres.c # The following files contain SQL-annotated statements that we need to extact SQL_FILES := \ @@ -42,12 +45,13 @@ SQL_FILES := \ wallet/statements.po: $(SQL_FILES) xgettext -kNAMED_SQL -kSQL --add-location --no-wrap --omit-header -o $@ $(SQL_FILES) -wallet/gen_db_sqlite3.c: wallet/statements.po devtools/sql-rewrite.py - devtools/sql-rewrite.py wallet/statements.po sqlite3 > wallet/gen_db_sqlite3.c +wallet/gen_db_%.c: wallet/statements.po devtools/sql-rewrite.py + devtools/sql-rewrite.py wallet/statements.po $* > wallet/gen_db_$*.c wallet-clean: $(RM) $(WALLET_LIB_OBJS) $(RM) wallet/statements.po $(RM) wallet/gen_db_sqlite3.c + $(RM) wallet/gen_db_postgres.c include wallet/test/Makefile diff --git a/wallet/db.c b/wallet/db.c index 2e741ab1f341..6176af2e27e9 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -10,7 +10,6 @@ #include #include -#define DB_FILE "lightningd.sqlite3" #define NSEC_IN_SEC 1000000000 struct migration { @@ -27,9 +26,9 @@ static struct migration dbmigrations[] = { {SQL("CREATE TABLE version (version INTEGER)"), NULL}, {SQL("INSERT INTO version VALUES (1)"), NULL}, {SQL("CREATE TABLE outputs (" - " prev_out_tx CHAR(64)" + " prev_out_tx BLOB" ", prev_out_index INTEGER" - ", value INTEGER" + ", value BIGINT" ", type INTEGER" ", status INTEGER" ", keyindex INTEGER" @@ -42,39 +41,46 @@ static struct migration dbmigrations[] = { ");"), NULL}, {SQL("CREATE TABLE shachains (" - " id INTEGER" - ", min_index INTEGER" - ", num_valid INTEGER" + " id BIGSERIAL" + ", min_index BIGINT" + ", num_valid BIGINT" ", PRIMARY KEY (id)" ");"), NULL}, {SQL("CREATE TABLE shachain_known (" - " shachain_id INTEGER REFERENCES shachains(id) ON DELETE CASCADE" + " shachain_id BIGINT REFERENCES shachains(id) ON DELETE CASCADE" ", pos INTEGER" - ", idx INTEGER" + ", idx BIGINT" ", hash BLOB" ", PRIMARY KEY (shachain_id, pos)" ");"), NULL}, + {SQL("CREATE TABLE peers (" + " id BIGSERIAL" + ", node_id BLOB UNIQUE" /* pubkey */ + ", address TEXT" + ", PRIMARY KEY (id)" + ");"), + NULL}, {SQL("CREATE TABLE channels (" - " id INTEGER," /* chan->id */ - " peer_id INTEGER REFERENCES peers(id) ON DELETE CASCADE," - " short_channel_id BLOB," - " channel_config_local INTEGER," - " channel_config_remote INTEGER," + " id BIGSERIAL," /* chan->id */ + " peer_id BIGINT REFERENCES peers(id) ON DELETE CASCADE," + " short_channel_id TEXT," + " channel_config_local BIGINT," + " channel_config_remote BIGINT," " state INTEGER," " funder INTEGER," " channel_flags INTEGER," " minimum_depth INTEGER," - " next_index_local INTEGER," - " next_index_remote INTEGER," - " next_htlc_id INTEGER, " + " next_index_local BIGINT," + " next_index_remote BIGINT," + " next_htlc_id BIGINT," " funding_tx_id BLOB," " funding_tx_outnum INTEGER," - " funding_satoshi INTEGER," + " funding_satoshi BIGINT," " funding_locked_remote INTEGER," - " push_msatoshi INTEGER," - " msatoshi_local INTEGER," /* our_msatoshi */ + " push_msatoshi BIGINT," + " msatoshi_local BIGINT," /* our_msatoshi */ /* START channel_info */ " fundingkey_remote BLOB," " revocation_basepoint_remote BLOB," @@ -86,10 +92,10 @@ static struct migration dbmigrations[] = { " local_feerate_per_kw INTEGER," " remote_feerate_per_kw INTEGER," /* END channel_info */ - " shachain_remote_id INTEGER," + " shachain_remote_id BIGINT," " shutdown_scriptpubkey_remote BLOB," - " shutdown_keyidx_local INTEGER," - " last_sent_commit_state INTEGER," + " shutdown_keyidx_local BIGINT," + " last_sent_commit_state BIGINT," " last_sent_commit_id INTEGER," " last_tx BLOB," " last_sig BLOB," @@ -98,31 +104,24 @@ static struct migration dbmigrations[] = { " PRIMARY KEY (id)" ");"), NULL}, - {SQL("CREATE TABLE peers (" - " id INTEGER" - ", node_id BLOB UNIQUE" /* pubkey */ - ", address TEXT" - ", PRIMARY KEY (id)" - ");"), - NULL}, {SQL("CREATE TABLE channel_configs (" - " id INTEGER," - " dust_limit_satoshis INTEGER," - " max_htlc_value_in_flight_msat INTEGER," - " channel_reserve_satoshis INTEGER," - " htlc_minimum_msat INTEGER," + " id BIGSERIAL," + " dust_limit_satoshis BIGINT," + " max_htlc_value_in_flight_msat BIGINT," + " channel_reserve_satoshis BIGINT," + " htlc_minimum_msat BIGINT," " to_self_delay INTEGER," " max_accepted_htlcs INTEGER," " PRIMARY KEY (id)" ");"), NULL}, {SQL("CREATE TABLE channel_htlcs (" - " id INTEGER," - " channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE," - " channel_htlc_id INTEGER," + " id BIGSERIAL," + " channel_id BIGINT REFERENCES channels(id) ON DELETE CASCADE," + " channel_htlc_id BIGINT," " direction INTEGER," - " origin_htlc INTEGER," - " msatoshi INTEGER," + " origin_htlc BIGINT," + " msatoshi BIGINT," " cltv_expiry INTEGER," " payment_hash BLOB," " payment_key BLOB," @@ -136,9 +135,9 @@ static struct migration dbmigrations[] = { ");"), NULL}, {SQL("CREATE TABLE invoices (" - " id INTEGER," + " id BIGSERIAL," " state INTEGER," - " msatoshi INTEGER," + " msatoshi BIGINT," " payment_hash BLOB," " payment_key BLOB," " label TEXT," @@ -148,28 +147,28 @@ static struct migration dbmigrations[] = { ");"), NULL}, {SQL("CREATE TABLE payments (" - " id INTEGER," + " id BIGSERIAL," " timestamp INTEGER," " status INTEGER," " payment_hash BLOB," " direction INTEGER," " destination BLOB," - " msatoshi INTEGER," + " msatoshi BIGINT," " PRIMARY KEY (id)," " UNIQUE (payment_hash)" ");"), NULL}, /* Add expiry field to invoices (effectively infinite). */ - {SQL("ALTER TABLE invoices ADD expiry_time INTEGER;"), NULL}, + {SQL("ALTER TABLE invoices ADD expiry_time BIGINT;"), NULL}, {SQL("UPDATE invoices SET expiry_time=9223372036854775807;"), NULL}, /* Add pay_index field to paid invoices (initially, same order as id). */ - {SQL("ALTER TABLE invoices ADD pay_index INTEGER;"), NULL}, + {SQL("ALTER TABLE invoices ADD pay_index BIGINT;"), NULL}, {SQL("CREATE UNIQUE INDEX invoices_pay_index ON invoices(pay_index);"), NULL}, {SQL("UPDATE invoices SET pay_index=id WHERE state=1;"), NULL}, /* only paid invoice */ /* Create next_pay_index variable (highest pay_index). */ - {SQL("INSERT OR REPLACE INTO vars(name, val)" + {SQL("INSERT INTO vars(name, val)" " VALUES('next_pay_index', " " COALESCE((SELECT MAX(pay_index) FROM invoices WHERE state=1), 0) " "+ 1" @@ -178,14 +177,13 @@ static struct migration dbmigrations[] = { /* Create first_block field; initialize from channel id if any. * This fails for channels still awaiting lockin, but that only applies to * pre-release software, so it's forgivable. */ - {SQL("ALTER TABLE channels ADD first_blocknum INTEGER;"), NULL}, - {SQL("UPDATE channels SET first_blocknum=CAST(short_channel_id AS INTEGER) " - "WHERE short_channel_id IS NOT NULL;"), + {SQL("ALTER TABLE channels ADD first_blocknum BIGINT;"), NULL}, + {SQL("UPDATE channels SET first_blocknum=1 WHERE short_channel_id IS NOT NULL;"), NULL}, - {SQL("ALTER TABLE outputs ADD COLUMN channel_id INTEGER;"), NULL}, + {SQL("ALTER TABLE outputs ADD COLUMN channel_id BIGINT;"), NULL}, {SQL("ALTER TABLE outputs ADD COLUMN peer_id BLOB;"), NULL}, {SQL("ALTER TABLE outputs ADD COLUMN commitment_point BLOB;"), NULL}, - {SQL("ALTER TABLE invoices ADD COLUMN msatoshi_received INTEGER;"), NULL}, + {SQL("ALTER TABLE invoices ADD COLUMN msatoshi_received BIGINT;"), NULL}, /* Normally impossible, so at least we'll know if databases are ancient. */ {SQL("UPDATE invoices SET msatoshi_received=0 WHERE state=1;"), NULL}, {SQL("ALTER TABLE channels ADD COLUMN last_was_revoke INTEGER;"), NULL}, @@ -194,12 +192,12 @@ static struct migration dbmigrations[] = { * rename & copy, which works because there are no triggers etc. */ {SQL("ALTER TABLE payments RENAME TO temp_payments;"), NULL}, {SQL("CREATE TABLE payments (" - " id INTEGER," + " id BIGSERIAL," " timestamp INTEGER," " status INTEGER," " payment_hash BLOB," " destination BLOB," - " msatoshi INTEGER," + " msatoshi BIGINT," " PRIMARY KEY (id)," " UNIQUE (payment_hash)" ");"), @@ -214,9 +212,9 @@ static struct migration dbmigrations[] = { {SQL("ALTER TABLE payments ADD COLUMN path_secrets BLOB;"), NULL}, /* Create time-of-payment of invoice, default already-paid * invoices to current time. */ - {SQL("ALTER TABLE invoices ADD paid_timestamp INTEGER;"), NULL}, + {SQL("ALTER TABLE invoices ADD paid_timestamp BIGINT;"), NULL}, {SQL("UPDATE invoices" - " SET paid_timestamp = strftime('%s', 'now')" + " SET paid_timestamp = CURRENT_TIMESTAMP()" " WHERE state = 1;"), NULL}, /* We need to keep the route node pubkeys and short channel ids to @@ -224,7 +222,7 @@ static struct migration dbmigrations[] = { * because we cannot safely save them as blobs due to byteorder * concerns. */ {SQL("ALTER TABLE payments ADD COLUMN route_nodes BLOB;"), NULL}, - {SQL("ALTER TABLE payments ADD COLUMN route_channels TEXT;"), NULL}, + {SQL("ALTER TABLE payments ADD COLUMN route_channels BLOB;"), NULL}, {SQL("CREATE TABLE htlc_sigs (channelid INTEGER REFERENCES channels(id) ON " "DELETE CASCADE, signature BLOB);"), NULL}, @@ -297,7 +295,7 @@ static struct migration dbmigrations[] = { NULL}, /* erring_index */ {SQL("ALTER TABLE payments ADD failcode INTEGER;"), NULL}, /* failcode */ {SQL("ALTER TABLE payments ADD failnode BLOB;"), NULL}, /* erring_node */ - {SQL("ALTER TABLE payments ADD failchannel BLOB;"), + {SQL("ALTER TABLE payments ADD failchannel TEXT;"), NULL}, /* erring_channel */ {SQL("ALTER TABLE payments ADD failupdate BLOB;"), NULL}, /* channel_update - can be NULL*/ @@ -310,14 +308,14 @@ static struct migration dbmigrations[] = { " WHERE status <> 0;"), NULL}, /* PAYMENT_PENDING */ /* -- Routing statistics -- */ - {SQL("ALTER TABLE channels ADD in_payments_offered INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD in_payments_fulfilled INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD in_msatoshi_offered INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD in_msatoshi_fulfilled INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD out_payments_offered INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD out_payments_fulfilled INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD out_msatoshi_offered INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD out_msatoshi_fulfilled INTEGER;"), NULL}, + {SQL("ALTER TABLE channels ADD in_payments_offered INTEGER DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD in_payments_fulfilled INTEGER DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD in_msatoshi_offered BIGINT DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD in_msatoshi_fulfilled BIGINT DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD out_payments_offered INTEGER DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD out_payments_fulfilled INTEGER DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD out_msatoshi_offered BIGINT DEFAULT 0;"), NULL}, + {SQL("ALTER TABLE channels ADD out_msatoshi_fulfilled BIGINT DEFAULT 0;"), NULL}, {SQL("UPDATE channels" " SET in_payments_offered = 0, in_payments_fulfilled = 0" " , in_msatoshi_offered = 0, in_msatoshi_fulfilled = 0" @@ -327,12 +325,12 @@ static struct migration dbmigrations[] = { NULL}, /* -- Routing statistics ends --*/ /* Record the msatoshi actually sent in a payment. */ - {SQL("ALTER TABLE payments ADD msatoshi_sent INTEGER;"), NULL}, + {SQL("ALTER TABLE payments ADD msatoshi_sent BIGINT;"), NULL}, {SQL("UPDATE payments SET msatoshi_sent = msatoshi;"), NULL}, /* Delete dangling utxoset entries due to Issue #1280 */ {SQL("DELETE FROM utxoset WHERE blockheight IN (" " SELECT DISTINCT(blockheight)" - " FROM utxoset LEFT OUTER JOIN blocks on (blockheight == " + " FROM utxoset LEFT OUTER JOIN blocks on (blockheight = " "blocks.height) " " WHERE blocks.hash IS NULL" ");"), @@ -346,8 +344,8 @@ static struct migration dbmigrations[] = { "max_possible_feerate=250000;"), NULL}, /* -- Min and max msatoshi_to_us -- */ - {SQL("ALTER TABLE channels ADD msatoshi_to_us_min INTEGER;"), NULL}, - {SQL("ALTER TABLE channels ADD msatoshi_to_us_max INTEGER;"), NULL}, + {SQL("ALTER TABLE channels ADD msatoshi_to_us_min BIGINT;"), NULL}, + {SQL("ALTER TABLE channels ADD msatoshi_to_us_max BIGINT;"), NULL}, {SQL("UPDATE channels" " SET msatoshi_to_us_min = msatoshi_local" " , msatoshi_to_us_max = msatoshi_local" @@ -374,8 +372,8 @@ static struct migration dbmigrations[] = { /* -- Detailed payment faiure ends -- */ {SQL("CREATE TABLE channeltxs (" /* The id serves as insertion order and short ID */ - " id INTEGER" - ", channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE" + " id BIGSERIAL" + ", channel_id BIGINT REFERENCES channels(id) ON DELETE CASCADE" ", type INTEGER" ", transaction_id BLOB REFERENCES transactions(id) ON DELETE CASCADE" /* The input_num is only used by the txo_watch, 0 if txwatch */ @@ -394,8 +392,9 @@ static struct migration dbmigrations[] = { /* Now make sure we have the lower bound block with the first_blocknum * height. This may introduce a block with NULL height if we didn't have any * blocks, remove that in the next. */ - {SQL("INSERT OR IGNORE INTO blocks (height) VALUES ((SELECT " - "MIN(first_blocknum) FROM channels));"), + {SQL("INSERT INTO blocks (height) VALUES ((SELECT " + "MIN(first_blocknum) FROM channels)) " + "ON CONFLICT(height) DO NOTHING;"), NULL}, {SQL("DELETE FROM blocks WHERE height IS NULL;"), NULL}, /* -- End of PR #1398 -- */ @@ -411,12 +410,12 @@ static struct migration dbmigrations[] = { * deleted when the HTLC entries or the channel entries are * deleted to avoid unexpected drops in statistics. */ {SQL("CREATE TABLE forwarded_payments (" - " in_htlc_id INTEGER REFERENCES channel_htlcs(id) ON DELETE SET NULL" - ", out_htlc_id INTEGER REFERENCES channel_htlcs(id) ON DELETE SET NULL" - ", in_channel_scid INTEGER" - ", out_channel_scid INTEGER" - ", in_msatoshi INTEGER" - ", out_msatoshi INTEGER" + " in_htlc_id BIGINT REFERENCES channel_htlcs(id) ON DELETE SET NULL" + ", out_htlc_id BIGINT REFERENCES channel_htlcs(id) ON DELETE SET NULL" + ", in_channel_scid BIGINT" + ", out_channel_scid BIGINT" + ", in_msatoshi BIGINT" + ", out_msatoshi BIGINT" ", state INTEGER" ", UNIQUE(in_htlc_id, out_htlc_id)" ");"), @@ -434,9 +433,9 @@ static struct migration dbmigrations[] = { {SQL("ALTER TABLE channels ADD feerate_base INTEGER;"), NULL}, {SQL("ALTER TABLE channels ADD feerate_ppm INTEGER;"), NULL}, {NULL, migrate_pr2342_feerate_per_channel}, - {SQL("ALTER TABLE channel_htlcs ADD received_time INTEGER"), NULL}, - {SQL("ALTER TABLE forwarded_payments ADD received_time INTEGER"), NULL}, - {SQL("ALTER TABLE forwarded_payments ADD resolved_time INTEGER"), NULL}, + {SQL("ALTER TABLE channel_htlcs ADD received_time BIGINT"), NULL}, + {SQL("ALTER TABLE forwarded_payments ADD received_time BIGINT"), NULL}, + {SQL("ALTER TABLE forwarded_payments ADD resolved_time BIGINT"), NULL}, {SQL("ALTER TABLE channels ADD remote_upfront_shutdown_script BLOB;"), NULL}, /* PR #2524: Add failcode into forward_payment */ @@ -445,12 +444,12 @@ static struct migration dbmigrations[] = { {SQL("ALTER TABLE channels ADD remote_ann_node_sig BLOB;"), NULL}, {SQL("ALTER TABLE channels ADD remote_ann_bitcoin_sig BLOB;"), NULL}, /* Additional information for transaction tracking and listing */ - {SQL("ALTER TABLE transactions ADD type INTEGER;"), NULL}, + {SQL("ALTER TABLE transactions ADD type BIGINT;"), NULL}, /* Not a foreign key on purpose since we still delete channels from * the DB which would remove this. It is mainly used to group payments * in the list view anyway, e.g., show all close and htlc transactions * as a single bundle. */ - {SQL("ALTER TABLE transactions ADD channel_id INTEGER;"), NULL}, + {SQL("ALTER TABLE transactions ADD channel_id BIGINT;"), NULL}, /* Convert pre-Adelaide short_channel_ids */ {SQL("UPDATE channels" " SET short_channel_id = REPLACE(short_channel_id, ':', 'x')" @@ -458,8 +457,12 @@ static struct migration dbmigrations[] = { {SQL("UPDATE payments SET failchannel = REPLACE(failchannel, ':', 'x')" " WHERE failchannel IS NOT NULL;"), NULL }, /* option_static_remotekey is nailed at creation time. */ - {SQL("ALTER TABLE channels ADD COLUMN option_static_remotekey" - " DEFAULT FALSE;"), NULL }, + {SQL("ALTER TABLE channels ADD COLUMN option_static_remotekey INTEGER" + " DEFAULT 0;"), NULL }, + {SQL("ALTER TABLE vars ADD COLUMN intval INTEGER"), NULL}, + {SQL("ALTER TABLE vars ADD COLUMN blobval BLOB"), NULL}, + {SQL("UPDATE vars SET intval = CAST(val AS INTEGER) WHERE name IN ('bip32_max_index', 'last_processed_block', 'next_pay_index')"), NULL}, + {SQL("UPDATE vars SET blobval = CAST(val AS BLOB) WHERE name = 'genesis_hash'"), NULL}, }; /* Leak tracking. */ @@ -480,6 +483,9 @@ static void db_assert_no_outstanding_statements(struct db *db) static void db_stmt_free(struct db_stmt *stmt) { + if (!stmt->executed) + fatal("Freeing an un-executed statement from %s: %s", + stmt->location, stmt->query->query); if (stmt->inner_stmt) stmt->db->config->stmt_free_fn(stmt); assert(stmt->inner_stmt == NULL); @@ -503,7 +509,7 @@ struct db_stmt *db_prepare_v2_(const char *location, struct db *db, /* Look up the query by its ID */ for (size_t i = 0; i < db->config->num_queries; i++) { - if (streq(query_id, db->config->queries[i].query)) { + if (streq(query_id, db->config->queries[i].name)) { stmt->query = &db->config->queries[i]; break; } @@ -572,12 +578,14 @@ const unsigned char *db_column_text(struct db_stmt *stmt, int col) size_t db_count_changes(struct db_stmt *stmt) { + assert(stmt->executed); return stmt->db->config->count_changes_fn(stmt); } u64 db_last_insert_id_v2(struct db_stmt *stmt TAKES) { u64 id; + assert(stmt->executed); id = stmt->db->config->last_insert_id_fn(stmt); if (taken(stmt)) @@ -645,24 +653,26 @@ void db_commit_transaction(struct db *db) db->in_transaction = NULL; } -static void setup_open_db(struct db *db) -{ - /* This must be outside a transaction, so catch it */ - assert(!db->in_transaction); - - db_prepare_for_changes(db); - if (db->config->setup_fn) - db->config->setup_fn(db); - db_report_changes(db, NULL, 0); -} - -static struct db_config *db_config_find(const char *driver_name) +static struct db_config *db_config_find(const char *dsn) { size_t num_configs; struct db_config **configs = autodata_get(db_backends, &num_configs); - for (size_t i=0; iname)) + const char *sep, *driver_name; + sep = strstr(dsn, "://"); + + if (!sep) + db_fatal("%s doesn't look like a valid data-source name (missing \"://\" separator.", dsn); + + driver_name = tal_strndup(tmpctx, dsn, sep - dsn); + + for (size_t i=0; iname)) { + tal_free(driver_name); return configs[i]; + } + } + + tal_free(driver_name); return NULL; } @@ -671,38 +681,29 @@ static struct db_config *db_config_find(const char *driver_name) */ static struct db *db_open(const tal_t *ctx, char *filename) { - int err; struct db *db; - sqlite3 *sql; - const char *driver_name = "sqlite3"; - - int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE; - err = sqlite3_open_v2(filename, &sql, flags, NULL); - - if (err != SQLITE_OK) { - db_fatal("failed to open database %s: %s", filename, - sqlite3_errstr(err)); - } db = tal(ctx, struct db); db->filename = tal_strdup(db, filename); - db->sql = sql; - db->config = NULL; list_head_init(&db->pending_statements); + if (!strstr(db->filename, "://")) + db_fatal("Could not extract driver name from \"%s\"", db->filename); - db->config = db_config_find(driver_name); + db->config = db_config_find(db->filename); if (!db->config) - db_fatal("Unable to find DB driver for %s", driver_name); - - // FIXME(cdecker) Once we parse DB connection strings this needs to be - // instantiated correctly. - db->conn = sql; + db_fatal("Unable to find DB driver for %s", db->filename); tal_add_destructor(db, destroy_db); db->in_transaction = NULL; db->changes = NULL; - setup_open_db(db); + /* This must be outside a transaction, so catch it */ + assert(!db->in_transaction); + + db_prepare_for_changes(db); + if (db->config->setup_fn && !db->config->setup_fn(db)) + fatal("Error calling DB setup: %s", db->error); + db_report_changes(db, NULL, 0); return db; } @@ -719,13 +720,23 @@ static int db_get_version(struct db *db) { int res = -1; struct db_stmt *stmt = db_prepare_v2(db, SQL("SELECT version FROM version LIMIT 1")); + + /* + * Tentatively execute a query, but allow failures. Some databases + * like postgres will terminate the DB transaction if there is an + * error during the execution of a query, e.g., trying to access a + * table that doesn't exist yet, so we need to terminate and restart + * the DB transaction. + */ if (!db_query_prepared(stmt)) { + db_commit_transaction(stmt->db); + db_begin_transaction(stmt->db); tal_free(stmt); return res; } if (db_step(stmt)) - res = db_column_u64(stmt, 0); + res = db_column_int(stmt, 0); tal_free(stmt); return res; @@ -768,7 +779,7 @@ static void db_migrate(struct lightningd *ld, struct db *db, struct log *log) /* Finally update the version number in the version table */ stmt = db_prepare_v2(db, SQL("UPDATE version SET version=?;")); - db_bind_u64(stmt, 0, available); + db_bind_int(stmt, 0, available); db_exec_prepared_v2(stmt); tal_free(stmt); @@ -776,7 +787,7 @@ static void db_migrate(struct lightningd *ld, struct db *db, struct log *log) if (current != orig) { stmt = db_prepare_v2( db, SQL("INSERT INTO db_upgrades VALUES (?, ?);")); - db_bind_u64(stmt, 0, orig); + db_bind_int(stmt, 0, orig); db_bind_text(stmt, 1, version()); db_exec_prepared_v2(stmt); tal_free(stmt); @@ -787,8 +798,7 @@ static void db_migrate(struct lightningd *ld, struct db *db, struct log *log) struct db *db_setup(const tal_t *ctx, struct lightningd *ld, struct log *log) { - struct db *db = db_open(ctx, DB_FILE); - + struct db *db = db_open(ctx, ld->wallet_dsn); db_migrate(ld, db, log); return db; } @@ -797,13 +807,13 @@ s64 db_get_intvar(struct db *db, char *varname, s64 defval) { s64 res = defval; struct db_stmt *stmt = db_prepare_v2( - db, SQL("SELECT val FROM vars WHERE name= ? LIMIT 1")); + db, SQL("SELECT intval FROM vars WHERE name= ? LIMIT 1")); db_bind_text(stmt, 0, varname); if (!db_query_prepared(stmt)) goto done; if (db_step(stmt)) - res = atol((const char*)db_column_text(stmt, 0)); + res = db_column_int(stmt, 0); done: tal_free(stmt); @@ -812,10 +822,9 @@ s64 db_get_intvar(struct db *db, char *varname, s64 defval) void db_set_intvar(struct db *db, char *varname, s64 val) { - char *v = tal_fmt(NULL, "%"PRIi64, val); size_t changes; - struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET val=? WHERE name=?;")); - db_bind_text(stmt, 0, v); + struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET intval=? WHERE name=?;")); + db_bind_int(stmt, 0, val); db_bind_text(stmt, 1, varname); if (!db_exec_prepared_v2(stmt)) db_fatal("Error executing update: %s", stmt->error); @@ -823,14 +832,13 @@ void db_set_intvar(struct db *db, char *varname, s64 val) tal_free(stmt); if (changes == 0) { - stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, val) VALUES (?, ?);")); + stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, intval) VALUES (?, ?);")); db_bind_text(stmt, 0, varname); - db_bind_text(stmt, 1, v); + db_bind_int(stmt, 1, val); if (!db_exec_prepared_v2(stmt)) db_fatal("Error executing insert: %s", stmt->error); tal_free(stmt); } - tal_free(v); } /* Will apply the current config fee settings to all channels */ diff --git a/wallet/db.h b/wallet/db.h index 2444c17723d4..59ec244ab0ad 100644 --- a/wallet/db.h +++ b/wallet/db.h @@ -40,7 +40,7 @@ struct db; * devtools/sql-rewrite.py needs to change as well, since they need to * generate identical names to work correctly. */ -#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__LINE__) ":" stringify(__COUNTER__), x) +#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__COUNTER__), x) /** diff --git a/wallet/db_common.h b/wallet/db_common.h index 74b21b3ed834..03ef9a345c7f 100644 --- a/wallet/db_common.h +++ b/wallet/db_common.h @@ -15,7 +15,6 @@ struct db { char *filename; const char *in_transaction; - sqlite3 *sql; /* DB-specific context */ void *conn; @@ -56,7 +55,7 @@ enum db_binding_type { struct db_binding { enum db_binding_type type; union { - int i; + s32 i; u64 u64; const char* text; const u8 *blob; @@ -86,6 +85,8 @@ struct db_stmt { void *inner_stmt; bool executed; + + int row; }; struct db_config { diff --git a/wallet/db_postgres.c b/wallet/db_postgres.c new file mode 100644 index 000000000000..1c465e3aade4 --- /dev/null +++ b/wallet/db_postgres.c @@ -0,0 +1,265 @@ +#include +#include "gen_db_postgres.c" +#include +#include +#include +#include +#include + +#if HAVE_POSTGRES +/* Indented in order not to trigger the inclusion order check */ + #include + +/* Cherry-picked from here: libpq/src/interfaces/ecpg/ecpglib/pg_type.h */ +#define BYTEAOID 17 +#define INT8OID 20 +#define INT4OID 23 +#define TEXTOID 25 + +static bool db_postgres_setup(struct db *db) +{ + db->conn = PQconnectdb(db->filename); + + if (PQstatus(db->conn) != CONNECTION_OK) { + db->error = tal_fmt(db, "Could not connect to %s: %s", db->filename, PQerrorMessage(db->conn)); + db->conn = NULL; + return false; + } + return true; +} + +static bool db_postgres_begin_tx(struct db *db) +{ + assert(db->conn); + PGresult *res; + res = PQexec(db->conn, "BEGIN;"); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + db->error = tal_fmt(db, "BEGIN command failed: %s", + PQerrorMessage(db->conn)); + PQclear(res); + return false; + } + PQclear(res); + return true; +} + +static bool db_postgres_commit_tx(struct db *db) +{ + assert(db->conn); + PGresult *res; + res = PQexec(db->conn, "COMMIT;"); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + db->error = tal_fmt(db, "COMMIT command failed: %s", + PQerrorMessage(db->conn)); + PQclear(res); + return false; + } + PQclear(res); + return true; +} + +static PGresult *db_postgres_do_exec(struct db_stmt *stmt) +{ + int slots = stmt->query->placeholders; + const char *paramValues[slots]; + int paramLengths[slots]; + int paramFormats[slots]; + Oid paramTypes[slots]; + int resultFormat = 1; /* We always want binary results. */ + + /* Since we pass in raw pointers to elements converted to network + * byte-order we need a place to temporarily stash them. */ + s32 ints[slots]; + u64 u64s[slots]; + + for (size_t i=0; ibindings[i]; + + switch (b->type) { + case DB_BINDING_UNINITIALIZED: + db_fatal("DB binding not initialized: position=%zu, " + "query=\"%s\n", + i, stmt->query->query); + case DB_BINDING_UINT64: + paramLengths[i] = 8; + paramFormats[i] = 1; + u64s[i] = cpu_to_be64(b->v.u64); + paramValues[i] = (char*)&u64s[i]; + paramTypes[i] = INT8OID; + break; + case DB_BINDING_INT: + paramLengths[i] = 4; + paramFormats[i] = 1; + ints[i] = cpu_to_be32(b->v.i); + paramValues[i] = (char*)&ints[i]; + paramTypes[i] = INT4OID; + break; + case DB_BINDING_BLOB: + paramLengths[i] = b->len; + paramFormats[i] = 1; + paramValues[i] = (char*)b->v.blob; + paramTypes[i] = BYTEAOID; + break; + case DB_BINDING_TEXT: + paramLengths[i] = b->len; + paramFormats[i] = 1; + paramValues[i] = (char*)b->v.text; + paramTypes[i] = TEXTOID; + break; + case DB_BINDING_NULL: + paramLengths[i] = 0; + paramFormats[i] = 1; + paramValues[i] = NULL; + paramTypes[i] = 0; + break; + } + } + return PQexecParams(stmt->db->conn, stmt->query->query, slots, + paramTypes, paramValues, paramLengths, paramFormats, + resultFormat); +} + +static bool db_postgres_query(struct db_stmt *stmt) +{ + stmt->inner_stmt = db_postgres_do_exec(stmt); + int res; + res = PQresultStatus(stmt->inner_stmt); + + if (res != PGRES_EMPTY_QUERY && res != PGRES_TUPLES_OK) { + stmt->error = PQerrorMessage(stmt->db->conn); + PQclear(stmt->inner_stmt); + stmt->inner_stmt = NULL; + return false; + } + stmt->row = -1; + return true; +} + +static bool db_postgres_step(struct db_stmt *stmt) +{ + stmt->row++; + if (stmt->row >= PQntuples(stmt->inner_stmt)) { + return false; + } + return true; +} + +static bool db_postgres_column_is_null(struct db_stmt *stmt, int col) +{ + PGresult *res = (PGresult*)stmt->inner_stmt; + return PQgetisnull(res, stmt->row, col); +} + +static u64 db_postgres_column_u64(struct db_stmt *stmt, int col) +{ + PGresult *res = (PGresult*)stmt->inner_stmt; + be64 bin; + size_t expected = sizeof(bin), actual = PQgetlength(res, stmt->row, col); + + if (expected != actual) + db_fatal( + "u64 field doesn't match size: expected %zu, actual %zu\n", + expected, actual); + + memcpy(&bin, PQgetvalue(res, stmt->row, col), sizeof(bin)); + return be64_to_cpu(bin); +} + +static s64 db_postgres_column_int(struct db_stmt *stmt, int col) +{ + PGresult *res = (PGresult*)stmt->inner_stmt; + be32 bin; + size_t expected = sizeof(bin), actual = PQgetlength(res, stmt->row, col); + + if (expected != actual) + db_fatal( + "s32 field doesn't match size: expected %zu, actual %zu\n", + expected, actual); + + memcpy(&bin, PQgetvalue(res, stmt->row, col), sizeof(bin)); + return be32_to_cpu(bin); +} + +static size_t db_postgres_column_bytes(struct db_stmt *stmt, int col) +{ + PGresult *res = (PGresult *)stmt->inner_stmt; + return PQgetlength(res, stmt->row, col); +} + +static const void *db_postgres_column_blob(struct db_stmt *stmt, int col) +{ + PGresult *res = (PGresult*)stmt->inner_stmt; + return PQgetvalue(res, stmt->row, col); +} + +static const unsigned char *db_postgres_column_text(struct db_stmt *stmt, int col) +{ + PGresult *res = (PGresult*)stmt->inner_stmt; + return (unsigned char*)PQgetvalue(res, stmt->row, col); +} + +static void db_postgres_stmt_free(struct db_stmt *stmt) +{ + if (stmt->inner_stmt) + PQclear(stmt->inner_stmt); + stmt->inner_stmt = NULL; +} + +static bool db_postgres_exec(struct db_stmt *stmt) +{ + bool ok; + stmt->inner_stmt = db_postgres_do_exec(stmt); + ok = PQresultStatus(stmt->inner_stmt) == PGRES_COMMAND_OK; + + if (!ok) + stmt->error = PQerrorMessage(stmt->db->conn); + + return ok; +} + +static u64 db_postgres_last_insert_id(struct db_stmt *stmt) +{ + PGresult *res = PQexec(stmt->db->conn, "SELECT lastval()"); + int id = atoi(PQgetvalue(res, 0, 0)); + PQclear(res); + return id; +} + +static size_t db_postgres_count_changes(struct db_stmt *stmt) +{ + PGresult *res = (PGresult*)stmt->inner_stmt; + char *count = PQcmdTuples(res); + return atoi(count); +} + +static void db_postgres_teardown(struct db *db) +{ +} + +struct db_config db_postgres_config = { + .name = "postgres", + .queries = db_postgres_queries, + .num_queries = DB_POSTGRES_QUERY_COUNT, + .exec_fn = db_postgres_exec, + .query_fn = db_postgres_query, + .step_fn = db_postgres_step, + .begin_tx_fn = &db_postgres_begin_tx, + .commit_tx_fn = &db_postgres_commit_tx, + .stmt_free_fn = db_postgres_stmt_free, + + .column_is_null_fn = db_postgres_column_is_null, + .column_u64_fn = db_postgres_column_u64, + .column_int_fn = db_postgres_column_int, + .column_bytes_fn = db_postgres_column_bytes, + .column_blob_fn = db_postgres_column_blob, + .column_text_fn = db_postgres_column_text, + + .last_insert_id_fn = db_postgres_last_insert_id, + .count_changes_fn = db_postgres_count_changes, + .setup_fn = db_postgres_setup, + .teardown_fn = db_postgres_teardown, +}; + +AUTODATA(db_backends, &db_postgres_config); + +#endif diff --git a/wallet/db_sqlite3.c b/wallet/db_sqlite3.c index 1dd606460a37..d6523aeac78e 100644 --- a/wallet/db_sqlite3.c +++ b/wallet/db_sqlite3.c @@ -24,8 +24,25 @@ static const char *db_sqlite3_fmt_error(struct db_stmt *stmt) static bool db_sqlite3_setup(struct db *db) { + char *filename; sqlite3_stmt *stmt; - int err; + sqlite3 *sql; + int err, flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE; + + if (!strstarts(db->filename, "sqlite3://") || strlen(db->filename) < 10) + db_fatal("Could not parse the wallet DSN: %s", db->filename); + + /* Strip the scheme from the dsn. */ + filename = db->filename + strlen("sqlite3://"); + + err = sqlite3_open_v2(filename, &sql, flags, NULL); + + if (err != SQLITE_OK) { + db_fatal("failed to open database %s: %s", filename, + sqlite3_errstr(err)); + } + db->conn = sql; + sqlite3_prepare_v2(db->conn, "PRAGMA foreign_keys = ON;", -1, &stmt, NULL); err = sqlite3_step(stmt); sqlite3_finalize(stmt); @@ -197,7 +214,8 @@ static size_t db_sqlite3_count_changes(struct db_stmt *stmt) static void db_sqlite3_close(struct db *db) { - sqlite3_close(db->sql); + sqlite3_close(db->conn); + db->conn = NULL; } static u64 db_sqlite3_last_insert_id(struct db_stmt *stmt) diff --git a/wallet/invoices.c b/wallet/invoices.c index 371c8ceed618..7e58361ed9a6 100644 --- a/wallet/invoices.c +++ b/wallet/invoices.c @@ -568,7 +568,7 @@ void invoices_waitany(const tal_t *ctx, stmt = db_prepare_v2(invoices->db, SQL("SELECT id" " FROM invoices" - " WHERE pay_index NOT NULL" + " WHERE pay_index IS NOT NULL" " AND pay_index > ?" " ORDER BY pay_index ASC LIMIT 1;")); db_bind_u64(stmt, 0, lastpay_index); diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index 1e444eff8672..80a2b8815d5c 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -49,14 +49,16 @@ void plugin_hook_db_sync(struct db *db UNNEEDED, const char **changes UNNEEDED, static struct db *create_test_db(void) { struct db *db; - char filename[] = "/tmp/ldb-XXXXXX"; + char *dsn, filename[] = "/tmp/ldb-XXXXXX"; int fd = mkstemp(filename); if (fd == -1) return NULL; close(fd); - db = db_open(NULL, filename); + dsn = tal_fmt(NULL, "sqlite3://%s", filename); + db = db_open(NULL, dsn); + tal_free(dsn); return db; } diff --git a/wallet/test/run-wallet.c b/wallet/test/run-wallet.c index c47851842784..9b539642c334 100644 --- a/wallet/test/run-wallet.c +++ b/wallet/test/run-wallet.c @@ -727,14 +727,16 @@ static void cleanup_test_wallet(struct wallet *w, char *filename) static struct wallet *create_test_wallet(struct lightningd *ld, const tal_t *ctx) { - char *filename = tal_fmt(ctx, "/tmp/ldb-XXXXXX"); + char *dsn, *filename = tal_fmt(ctx, "/tmp/ldb-XXXXXX"); int fd = mkstemp(filename); struct wallet *w = tal(ctx, struct wallet); static unsigned char badseed[BIP32_ENTROPY_LEN_128]; CHECK_MSG(fd != -1, "Unable to generate temp filename"); close(fd); - w->db = db_open(w, filename); + dsn = tal_fmt(NULL, "sqlite3://%s", filename); + w->db = db_open(w, dsn); + tal_free(dsn); tal_add_destructor2(w, cleanup_test_wallet, filename); list_head_init(&w->unstored_payments); diff --git a/wallet/wallet.c b/wallet/wallet.c index 9d3e11b80e84..eb95f35a40ec 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -604,6 +604,7 @@ bool wallet_shachain_add_hash(struct wallet *wallet, struct db_stmt *stmt; u32 pos = count_trailing_zeroes(index); struct sha256 s; + bool updated; BUILD_ASSERT(sizeof(s) == sizeof(*hash)); memcpy(&s, hash, sizeof(s)); @@ -622,13 +623,26 @@ bool wallet_shachain_add_hash(struct wallet *wallet, db_exec_prepared_v2(take(stmt)); stmt = db_prepare_v2(wallet->db, - SQL("REPLACE INTO shachain_known (shachain_id, " - "pos, idx, hash) VALUES (?, ?, ?, ?);")); - db_bind_u64(stmt, 0, chain->id); - db_bind_int(stmt, 1, pos); - db_bind_u64(stmt, 2, index); - db_bind_secret(stmt, 3, hash); - db_exec_prepared_v2(take(stmt)); + SQL("UPDATE shachain_known SET idx=?, hash=? " + "WHERE shachain_id=? AND pos=?")); + db_bind_u64(stmt, 0, index); + db_bind_secret(stmt, 1, hash); + db_bind_u64(stmt, 2, chain->id); + db_bind_int(stmt, 3, pos); + db_exec_prepared_v2(stmt); + updated = db_count_changes(stmt) == 1; + tal_free(stmt); + + if (!updated) { + stmt = db_prepare_v2( + wallet->db, SQL("INSERT INTO shachain_known (shachain_id, " + "pos, idx, hash) VALUES (?, ?, ?, ?);")); + db_bind_u64(stmt, 0, chain->id); + db_bind_int(stmt, 1, pos); + db_bind_u64(stmt, 2, index); + db_bind_secret(stmt, 3, hash); + db_exec_prepared_v2(take(stmt)); + } return true; } @@ -1104,12 +1118,12 @@ void wallet_channel_stats_load(struct wallet *w, /* This must succeed, since we know the channel exists */ assert(res); - stats->in_payments_offered = db_column_u64(stmt, 0); - stats->in_payments_fulfilled = db_column_u64(stmt, 1); + stats->in_payments_offered = db_column_int(stmt, 0); + stats->in_payments_fulfilled = db_column_int(stmt, 1); db_column_amount_msat(stmt, 2, &stats->in_msatoshi_offered); db_column_amount_msat(stmt, 3, &stats->in_msatoshi_fulfilled); - stats->out_payments_offered = db_column_u64(stmt, 4); - stats->out_payments_fulfilled = db_column_u64(stmt, 5); + stats->out_payments_offered = db_column_int(stmt, 4); + stats->out_payments_fulfilled = db_column_int(stmt, 5); db_column_amount_msat(stmt, 6, &stats->out_msatoshi_offered); db_column_amount_msat(stmt, 7, &stats->out_msatoshi_fulfilled); tal_free(stmt); @@ -1354,7 +1368,8 @@ void wallet_channel_save(struct wallet *w, struct channel *chan) else db_bind_null(stmt, 0); db_bind_u64(stmt, 1, chan->dbid); - db_exec_prepared_v2(take(stmt)); + db_exec_prepared_v2(stmt); + tal_free(stmt); } void wallet_channel_insert(struct wallet *w, struct channel *chan) @@ -1979,7 +1994,7 @@ struct htlc_stub *wallet_htlc_stubs(const tal_t *ctx, struct wallet *wallet, /* FIXME: merge these two enums */ stub.owner = db_column_int(stmt, 1)==DIRECTION_INCOMING?REMOTE:LOCAL; stub.cltv_expiry = db_column_int(stmt, 2); - stub.id = db_column_int(stmt, 3); + stub.id = db_column_u64(stmt, 3); db_column_sha256(stmt, 4, &payment_hash); ripemd160(&stub.ripemd, payment_hash.u.u8, sizeof(payment_hash.u)); @@ -2162,13 +2177,13 @@ static struct wallet_payment *wallet_stmt2payment(const tal_t *ctx, db_column_amount_msat(stmt, 10, &payment->msatoshi_sent); - if (!db_column_is_null(stmt, 11)) + if (!db_column_is_null(stmt, 11) && db_column_text(stmt, 11) != NULL) payment->label = tal_strdup(payment, (const char *)db_column_text(stmt, 11)); else payment->label = NULL; - if (!db_column_is_null(stmt, 12)) + if (!db_column_is_null(stmt, 12) && db_column_text(stmt, 12) != NULL) payment->bolt11 = tal_strdup( payment, (const char *)db_column_text(stmt, 12)); else @@ -2483,7 +2498,7 @@ bool wallet_network_check(struct wallet *w, { struct bitcoin_blkid chainhash; struct db_stmt *stmt = db_prepare_v2( - w->db, SQL("SELECT val FROM vars WHERE name='genesis_hash'")); + w->db, SQL("SELECT blobval FROM vars WHERE name='genesis_hash'")); db_query_prepared(stmt); if (db_step(stmt)) { @@ -2507,7 +2522,7 @@ bool wallet_network_check(struct wallet *w, tal_free(stmt); /* Still a pristine wallet, claim it for the chain * that we are running */ - stmt = db_prepare_v2(w->db, SQL("INSERT INTO vars (name, val) " + stmt = db_prepare_v2(w->db, SQL("INSERT INTO vars (name, blobval) " "VALUES ('genesis_hash', ?);")); db_bind_sha256d(stmt, 0, &chainparams->genesis_blockhash.shad); db_exec_prepared_v2(take(stmt)); @@ -2684,7 +2699,8 @@ void wallet_filteredblock_add(struct wallet *w, const struct filteredblock *fb) struct db_stmt *stmt; if (wallet_have_block(w, fb->height)) return; - stmt = db_prepare_v2(w->db, SQL("INSERT OR IGNORE INTO blocks " + + stmt = db_prepare_v2(w->db, SQL("INSERT INTO blocks " "(height, hash, prev_hash) " "VALUES (?, ?, ?);")); db_bind_int(stmt, 0, fb->height); @@ -2760,7 +2776,10 @@ struct outpoint *wallet_outpoint_for_scid(struct wallet *w, tal_t *ctx, op->txindex = short_channel_id_txnum(scid); op->outnum = short_channel_id_outnum(scid); db_column_sha256d(stmt, 0, &op->txid.shad); - op->spendheight = db_column_int(stmt, 1); + if (db_column_is_null(stmt, 1)) + op->spendheight = 0; + else + op->spendheight = db_column_int(stmt, 1); op->scriptpubkey = tal_arr(op, u8, db_column_bytes(stmt, 2)); memcpy(op->scriptpubkey, db_column_blob(stmt, 2), db_column_bytes(stmt, 2)); db_column_amount_sat(stmt, 3, &op->sat); @@ -2829,8 +2848,10 @@ void wallet_transaction_annotate(struct wallet *w, fatal("Attempting to annotate a transaction we don't have: %s", type_to_string(tmpctx, struct bitcoin_txid, txid)); - type |= db_column_int(stmt, 0); - if (channel_id == 0) + if (!db_column_is_null(stmt, 0)) + type |= db_column_u64(stmt, 0); + + if (channel_id == 0 && !db_column_is_null(stmt, 1)) channel_id = db_column_u64(stmt, 1); tal_free(stmt); @@ -2840,7 +2861,7 @@ void wallet_transaction_annotate(struct wallet *w, ", channel_id = ? " "WHERE id = ?")); - db_bind_int(stmt, 0, type); + db_bind_u64(stmt, 0, type); if (channel_id) db_bind_int(stmt, 1, channel_id); @@ -2863,7 +2884,7 @@ bool wallet_transaction_type(struct wallet *w, const struct bitcoin_txid *txid, return false; } - *type = db_column_int(stmt, 0); + *type = db_column_u64(stmt, 0); tal_free(stmt); return true; } @@ -2881,7 +2902,10 @@ u32 wallet_transaction_height(struct wallet *w, const struct bitcoin_txid *txid) return 0; } - blockheight = db_column_int(stmt, 0); + if (!db_column_is_null(stmt, 0)) + blockheight = db_column_int(stmt, 0); + else + blockheight = 0; tal_free(stmt); return blockheight; } @@ -2993,7 +3017,7 @@ struct channeltx *wallet_channeltxs_get(struct wallet *w, const tal_t *ctx, ", c.blockheight - t.blockheight + 1 AS depth" ", t.id as txid " "FROM channeltxs c " - "JOIN transactions t ON t.id == c.transaction_id " + "JOIN transactions t ON t.id = c.transaction_id " "WHERE c.channel_id = ? " "ORDER BY c.id ASC;")); db_bind_int(stmt, 0, channel_id); @@ -3015,6 +3039,59 @@ struct channeltx *wallet_channeltxs_get(struct wallet *w, const tal_t *ctx, return res; } +static bool wallet_forwarded_payment_update(struct wallet *w, + const struct htlc_in *in, + const struct htlc_out *out, + enum forward_status state, + enum onion_type failcode, + struct timeabs *resolved_time) +{ + struct db_stmt *stmt; + bool changed; + + /* We update based solely on the htlc_in since an HTLC cannot be + * associated with more than one forwarded payment. This saves us from + * having to have two versions of the update statement (one with and + * one without the htlc_out restriction).*/ + stmt = db_prepare_v2(w->db, + SQL("UPDATE forwarded_payments SET" + " in_msatoshi=?" + ", out_msatoshi=?" + ", state=?" + ", resolved_time=?" + ", failcode=?" + " WHERE in_htlc_id=?")); + db_bind_amount_msat(stmt, 0, &in->msat); + + if (out) { + db_bind_amount_msat(stmt, 1, &out->msat); + } else { + db_bind_null(stmt, 1); + } + + db_bind_int(stmt, 2, wallet_forward_status_in_db(state)); + + if (resolved_time != NULL) { + db_bind_timeabs(stmt, 3, *resolved_time); + } else { + db_bind_null(stmt, 3); + } + + if(failcode != 0) { + assert(state == FORWARD_FAILED || state == FORWARD_LOCAL_FAILED); + db_bind_int(stmt, 4, (int)failcode); + } else { + db_bind_null(stmt, 4); + } + + db_bind_u64(stmt, 5, in->dbid); + db_exec_prepared_v2(stmt); + changed = db_count_changes(stmt) != 0; + tal_free(stmt); + + return changed; +} + void wallet_forwarded_payment_add(struct wallet *w, const struct htlc_in *in, const struct htlc_out *out, enum forward_status state, @@ -3022,8 +3099,19 @@ void wallet_forwarded_payment_add(struct wallet *w, const struct htlc_in *in, { struct db_stmt *stmt; struct timeabs *resolved_time; + + if (state == FORWARD_SETTLED || state == FORWARD_FAILED) { + resolved_time = tal(tmpctx, struct timeabs); + *resolved_time = time_now(); + } else { + resolved_time = NULL; + } + + if (wallet_forwarded_payment_update(w, in, out, state, failcode, resolved_time)) + goto notify; + stmt = db_prepare_v2(w->db, - SQL("INSERT OR REPLACE INTO forwarded_payments (" + SQL("INSERT INTO forwarded_payments (" " in_htlc_id" ", out_htlc_id" ", in_channel_scid" @@ -3057,14 +3145,10 @@ void wallet_forwarded_payment_add(struct wallet *w, const struct htlc_in *in, db_bind_int(stmt, 6, wallet_forward_status_in_db(state)); db_bind_timeabs(stmt, 7, in->received_time); - if (state == FORWARD_SETTLED || state == FORWARD_FAILED) { - resolved_time = tal(tmpctx, struct timeabs); - *resolved_time = time_now(); + if (resolved_time != NULL) db_bind_timeabs(stmt, 8, *resolved_time); - } else { - resolved_time = NULL; + else db_bind_null(stmt, 8); - } if(failcode != 0) { assert(state == FORWARD_FAILED || state == FORWARD_LOCAL_FAILED); @@ -3075,6 +3159,7 @@ void wallet_forwarded_payment_add(struct wallet *w, const struct htlc_in *in, db_exec_prepared_v2(take(stmt)); +notify: notify_forward_event(w->ld, in, out, state, failcode, resolved_time); } @@ -3085,7 +3170,7 @@ struct amount_msat wallet_total_forward_fees(struct wallet *w) bool res; stmt = db_prepare_v2(w->db, SQL("SELECT" - " SUM(in_msatoshi - out_msatoshi) " + " CAST(COALESCE(SUM(in_msatoshi - out_msatoshi), 0) AS BIGINT)" "FROM forwarded_payments " "WHERE state = ?;")); db_bind_int(stmt, 0, wallet_forward_status_in_db(FORWARD_SETTLED)); @@ -3119,7 +3204,7 @@ const struct forwarding *wallet_forwarded_payments_get(struct wallet *w, ", f.resolved_time" ", f.failcode " "FROM forwarded_payments f " - "LEFT JOIN channel_htlcs hin ON (f.in_htlc_id == hin.id)")); + "LEFT JOIN channel_htlcs hin ON (f.in_htlc_id = hin.id)")); db_query_prepared(stmt); for (count=0; db_step(stmt); count++) {